• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_BASE_TENSOR_H_
18 #define MINDSPORE_CORE_IR_BASE_TENSOR_H_
19 
20 #include <future>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 #include <numeric>
25 #include <mutex>
26 #include <algorithm>
27 #include <condition_variable>
28 #include <utility>
29 #include <iomanip>
30 #include "ir/device_sync.h"
31 #include "ir/meta_tensor.h"
32 #include "utils/log_adapter.h"
33 #include "base/float16.h"
34 #include "base/bfloat16.h"
35 #include "utils/shape_utils.h"
36 #include "utils/ms_exception.h"
37 #include "ir/device_event.h"
38 #include "utils/os.h"
39 #include "ir/meta_grad_data.h"
40 #include "ir/tensor_data.h"
41 #include "utils/ms_utils_secure.h"
42 #include "mindspore/core/base/complex_storage.h"
43 #include "utils/temp_file_manager.h"
44 #include "utils/system/env.h"
45 
46 // brief mindspore namespace.
47 //
48 // mindspore namespace is the top level namespace of MindSpore project.
49 // Other namespace should be a sub namespace of mindspore namespace in the ME project.
50 namespace mindspore {
51 // brief mindspore::tensor namespace
// Synchronization state between a tensor's host and device copies.
// The "Immediately" variants request the sync to happen at once rather
// than at the next natural synchronization point.
enum TensorSyncStatus {
  kNoNeedSync,                       // Host and device data need no synchronization.
  kNeedSyncHostToDevice,             // Host data changed; device copy must be refreshed.
  kNeedSyncHostToDeviceImmediately,  // Same as above, but sync must happen immediately.
  kNeedSyncDeviceToHost,             // Device data changed; host copy must be refreshed.
  kNeedSyncDeviceToHostImmediately   // Same as above, but sync must happen immediately.
};
59 
// Compression scheme applied to a tensor's underlying data buffer.
// Values are explicit because they are part of the serialized representation.
enum TensorCompressionType {
  kNoCompression = 0,  // Plain, uncompressed data.
  kIndexing = 1,
  kSparse = 2,
  kFSE = 3,
  kBitPacking = 4,
  kFSEInt = 5,
  kFSEInfer = 6
};
69 
70 // A sub namespace in ME to support tensor related definition.
71 namespace tensor {
// Marker inserted when large tensors are summarized in ToString output.
constexpr auto kEllipsis = "...";
// Per-dimension element count above which a multi-dimension tensor is summarized.
constexpr auto kThreshold = 6;
// Element count above which a 1-D tensor is summarized.
constexpr auto kThreshold1D = 1000;

// Number of elements printed per line for 1-D tensors, by element category
// (used as the line-feed interval in TensorStringifier::OutputDataString).
constexpr auto kThreshold1DFloat = kThreshold * 2;
constexpr auto kThreshold1DInt = kThreshold * 4;
constexpr auto kThreshold1DBool = kThreshold * 2;
79 template <typename T, typename U>
NewData(const U * input,size_t size)80 std::unique_ptr<T[]> NewData(const U *input, size_t size) {
81   if (input == nullptr || size == 0) {
82     return nullptr;
83   }
84   if (size > INT32_MAX) {
85     MS_LOG(WARNING) << "Try to alloca a large memory, size is:" << size * sizeof(T);
86   }
87 
88   auto data = std::make_unique<T[]>(size);
89   if constexpr (!std::is_same<T, U>::value &&
90                 (std::is_same<T, float16>::value || std::is_same<U, float16>::value ||
91                  std::is_same<T, bfloat16>::value || std::is_same<U, bfloat16>::value ||
92                  std::is_same<T, ComplexStorage<float>>::value || std::is_same<U, ComplexStorage<float>>::value ||
93                  std::is_same<T, ComplexStorage<double>>::value || std::is_same<U, ComplexStorage<double>>::value)) {
94     // Because float16 and bfloat16 do not support implicit cast from/to other types,
95     // We can not use std::copy() on array of float16 and bfloat16, use a loop here.
96     for (size_t i = 0; i < size; ++i) {
97       data[i] = static_cast<T>(input[i]);
98     }
99   } else {
100     // otherwise, use std::copy for better performance.
101     std::copy(input, input + size, data.get());
102   }
103   return data;
104 }
105 
// Allocate a one-element T[] buffer holding `scalar` cast to T.
template <typename T, typename Scalar>
std::unique_ptr<T[]> NewData(Scalar scalar) {
  constexpr size_t kScalarLen = 1;
  auto buffer = std::make_unique<T[]>(kScalarLen);
  buffer[0] = static_cast<T>(scalar);
  return buffer;
}
112 
// Copy tensor data from raw buffer `data`, whose elements have runtime type
// `data_type`, into a newly allocated T[] of SizeOf(shape) elements,
// converting each element to T via NewData.
// Raises (MS_LOG(EXCEPTION)) when `data_type` is not supported here.
template <typename T>
std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId data_type) {
  const size_t size = SizeOf(shape);
  switch (data_type) {
    case kNumberTypeBool: {
      auto buf = static_cast<bool *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt8: {
      auto buf = static_cast<uint8_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt4: {
      // 4-bit integers are read from int8 storage.
      auto buf = static_cast<int8_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt8: {
      auto buf = static_cast<int8_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt16: {
      auto buf = static_cast<int16_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt32: {
      auto buf = static_cast<int32_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt64: {
      auto buf = static_cast<int64_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt16: {
      auto buf = static_cast<uint16_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt32: {
      auto buf = static_cast<uint32_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt64: {
      auto buf = static_cast<uint64_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat16: {
      auto buf = static_cast<float16 *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat32: {
      auto buf = static_cast<float *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat64: {
      auto buf = static_cast<double *>(data);
      return NewData<T>(buf, size);
    }
#ifndef KERNEL_EXECUTOR_ANDROID
    // bfloat16 support is compiled out for the Android kernel executor build.
    case kNumberTypeBFloat16: {
      auto buf = static_cast<bfloat16 *>(data);
      return NewData<T>(buf, size);
    }
#endif
    case kNumberTypeComplex64: {
      auto buf = static_cast<ComplexStorage<float> *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeComplex128: {
      auto buf = static_cast<ComplexStorage<double> *>(data);
      return NewData<T>(buf, size);
    }
    case kObjectTypeString: {
      // String tensors are copied as raw bytes.
      auto buf = static_cast<uint8_t *>(data);
      return NewData<T>(buf, size);
    }
    default:
      break;
  }
  MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
}
192 
193 template <typename T>
CopyData(const ShapeVector & shape,void * const data,size_t data_len)194 std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, size_t data_len) {
195   size_t size = SizeOf(shape);
196   if (size * sizeof(T) != data_len) {
197     MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
198                       << " item size " << sizeof(T);
199   }
200   auto buf = static_cast<T *>(data);
201   return NewData<T>(buf, size);
202 }
203 
204 // TensorStringifier provide methods to convert tensor data to its string representation.
205 template <typename T>
206 class TensorStringifier {
207  public:
TensorStringifier(const T * data,size_t data_size,size_t ndim)208   TensorStringifier(const T *data, size_t data_size, size_t ndim) : data_(data), data_size_(data_size), ndim_(ndim) {}
209   ~TensorStringifier() = default;
210 
ToString(TypeId,const ShapeVector & shape,bool use_comma)211   std::string ToString(TypeId, const ShapeVector &shape, bool use_comma) const {
212     constexpr auto valid =
213       std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
214       std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
215       std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value ||
216       std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value ||
217       std::is_same<T, ComplexStorage<float>>::value || std::is_same<T, ComplexStorage<double>>::value ||
218       std::is_same<T, bfloat16>::value;
219     static_assert(valid, "Type is invalid");
220     if (data_size_ == 0) {
221       return "";
222     }
223     if (data_ == nullptr) {
224       return "<uninitialized>";
225     }
226 
227     std::ostringstream ss;
228     if (data_size_ == 1 && ndim_ == 0) {  // Scalar
229       int max = 0;
230       OutputDataString(ss, 0, 0, 1, false, &max);
231       return ss.str();
232     }
233 
234     int num_width = 0;
235     ssize_t cursor = 0;
236     SummaryStringRecursive(ss, shape, &cursor, 0, use_comma, &num_width);
237     return ProcessPlaceholder(ss, num_width);
238   }
239 
240  private:
OutputFloatDataString(std::ostringstream & ss,bool isScalar,const T & value)241   static void OutputFloatDataString(std::ostringstream &ss, bool isScalar, const T &value) {
242     if (isScalar) {
243       ss << value;
244     } else {
245       // The placeholder of float16 is fixed at 11, while float/double is fixed at 15.
246       const int width = std::is_same<T, float16>::value ? 11 : 15;
247       // The printing precision of float16 is fixed at 4, while float/double is fixed at 8.
248       const int precision = std::is_same<T, float16>::value ? 4 : 8;
249       ss << std::setw(width) << std::setprecision(precision) << std::setiosflags(std::ios::scientific | std::ios::right)
250          << value;
251     }
252   }
253 
OutputBoolDataString(std::ostringstream & ss,bool isScalar,const T & value)254   static void OutputBoolDataString(std::ostringstream &ss, bool isScalar, const T &value) {
255     if (isScalar) {
256       ss << (value ? "True" : "False");
257     } else {
258       constexpr int bool_max_width = sizeof("False") - 1;
259       ss << std::setw(bool_max_width) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
260     }
261   }
262 
OutputOtherDataString(std::ostringstream & ss,bool isScalar,const T & value,int * max_width)263   static void OutputOtherDataString(std::ostringstream &ss, bool isScalar, const T &value, int *max_width) {
264     std::ostringstream value_ss;
265     if constexpr (std::is_same<T, uint8_t>::value) {
266       value_ss << static_cast<uint16_t>(value);
267     } else if constexpr (std::is_same<T, int8_t>::value) {
268       value_ss << static_cast<int16_t>(value);
269     } else {
270       value_ss << value;
271     }
272     auto value_str = value_ss.str();
273     if (!isScalar) {
274       const int width = static_cast<int>(value_str.size());
275       *max_width = std::max(*max_width, width);
276       // Add a padding string before the number, such as "###123", for subsequent replacement.
277       std::string pad(width, '#');
278       ss << pad;
279     }
280     ss << value_str;
281   }
282 
ProcessPlaceholder(const std::ostringstream & ss,int max_width)283   static std::string ProcessPlaceholder(const std::ostringstream &ss, int max_width) {
284     std::string str = ss.str();
285     if constexpr (std::is_same<T, bool>::value || std::is_same<T, float16>::value || std::is_same<T, float>::value ||
286                   std::is_same<T, double>::value) {
287       return str;
288     }
289     // Replace # with placeholder.
290     size_t index = str.find('#');
291     while (index != std::string::npos) {
292       size_t pos = index;
293       while (str[pos] == '#') {
294         pos++;
295       }
296       size_t len = pos - index;
297       std::string space(max_width - SizeToInt(len), ' ');
298       str = str.replace(index, len, space);
299       index = str.find('#', index);
300     }
301     return str;
302   }
303 
OutputDataString(std::ostringstream & ss,ssize_t cursor,ssize_t start,ssize_t end,bool use_comma,int * max_width)304   void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end, bool use_comma,
305                         int *max_width) const {
306     const bool isScalar = ndim_ == 0 && end - start == 1;
307     constexpr auto isBool = std::is_same<T, bool>::value;
308     constexpr auto isFloat =
309       std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
310     constexpr auto isComplex =
311       std::is_same<T, ComplexStorage<float>>::value || std::is_same<T, ComplexStorage<double>>::value;
312     constexpr int linefeedThreshold = isFloat ? kThreshold1DFloat : (isBool ? kThreshold1DBool : kThreshold1DInt);
313     for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
314       const auto value = data_[cursor + i];
315       if constexpr (isComplex) {
316         ss << value;
317       } else if constexpr (isFloat) {
318         OutputFloatDataString(ss, isScalar, value);
319       } else if (isBool) {
320         OutputBoolDataString(ss, isScalar, value);
321       } else {
322         OutputOtherDataString(ss, isScalar, value, max_width);
323       }
324       if (!isScalar && i != end - 1) {
325         if (use_comma) {
326           ss << ',';
327         }
328         ss << ' ';
329       }
330       if (!isScalar && ndim_ == 1 && end - start > (kThreshold >> 1) && (i + 1) % linefeedThreshold == 0) {
331         // Add a line feed every {threshold of type} for 1D tensor.
332         ss << '\n' << ' ';
333       }
334     }
335   }
336 
SummaryStringRecursive(std::ostringstream & ss,const ShapeVector & shape,ssize_t * cursor,ssize_t depth,bool use_comma,int * max_width)337   void SummaryStringRecursive(std::ostringstream &ss, const ShapeVector &shape, ssize_t *cursor, ssize_t depth,
338                               bool use_comma, int *max_width) const {
339     if (depth >= static_cast<ssize_t>(ndim_)) {
340       return;
341     }
342     ss << '[';
343     if (depth == static_cast<ssize_t>(ndim_) - 1) {  // Bottom dimension
344       ssize_t num = shape[depth];
345       if ((num > kThreshold && ndim_ > 1) || (num > kThreshold1D && ndim_ == 1)) {
346         OutputDataString(ss, *cursor, 0, kThreshold >> 1, use_comma, max_width);
347         ss << ' ' << kEllipsis << ' ';
348         OutputDataString(ss, *cursor, num - (kThreshold >> 1), num, use_comma, max_width);
349       } else {
350         OutputDataString(ss, *cursor, 0, num, use_comma, max_width);
351       }
352       *cursor += num;
353     } else {  // Middle dimension
354       ssize_t num = shape[depth];
355       // Handle the first half.
356       for (ssize_t i = 0; i < std::min(static_cast<ssize_t>(kThreshold >> 1), num); i++) {
357         if (i > 0) {
358           if (use_comma) {
359             ss << ',';
360           }
361           ss << '\n';
362           ss << std::setw(depth + 1) << ' ';  // Add the indent.
363         }
364         SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma, max_width);
365       }
366       // Handle the ignored part.
367       if (num > kThreshold) {
368         if (use_comma) {
369           ss << ',';
370         }
371         ss << '\n';
372         ss << std::setw(depth + 1) << ' ';  // Add the indent.
373         ss << kEllipsis;
374         // Ignored at this layer.
375         ssize_t ignored = shape[depth + 1];
376         const size_t offset = 2;
377         for (ssize_t i = depth + offset; i < static_cast<ssize_t>(ndim_); i++) {
378           ignored *= shape[i];
379         }
380         // Multiple with ignored layers number.
381         ignored *= (num - kThreshold);
382         *cursor += ignored;
383       }
384       // Handle the second half.
385       if (num > (kThreshold >> 1)) {
386         ssize_t iter_times =
387           std::min(static_cast<ssize_t>(num - (kThreshold >> 1)), static_cast<ssize_t>(kThreshold >> 1));
388         for (ssize_t i = 0; i < iter_times; i++) {
389           if (use_comma && (i != 0 || num <= kThreshold)) {  // Not just after ignored part || Not handle ignored part
390             ss << ',';
391           }
392           ss << '\n';
393           ss << std::setw(depth + 1) << ' ';  // Add the indent.
394           SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma, max_width);
395         }
396       }
397     }
398     ss << ']';
399   }
400 
401   const T *data_;
402   const size_t data_size_;
403   const size_t ndim_;
404 };
// Tensor data implementation.
// Owns a flat buffer of T elements. The buffer may be allocated lazily on
// first data() access, and may optionally be backed by an offload file on
// disk that is loaded on first access and removed at destruction.
template <typename T>
class TensorDataImpl : public TensorData {
 public:
  // Create lazily-allocated data: no buffer is allocated until data() is called.
  explicit TensorDataImpl(const ShapeVector &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
  // Destructor must not throw: offload-file cleanup failures are only logged.
  ~TensorDataImpl() override {
    try {
      RemoveOffloadFile();
    } catch (const std::exception &e) {
      MS_LOG(ERROR) << "Exception occurred when cleaning tensor. Error info " << e.what();
    } catch (...) {
      MS_LOG(ERROR) << "Exception occurred when cleaning tensor.";
    }
  }

  // Copy `data_len` bytes of raw data whose elements are already of type T.
  TensorDataImpl(const ShapeVector &shape, void *data, size_t data_len)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)) {}

  // Copy raw data of runtime type `data_type`, converting each element to T.
  TensorDataImpl(const ShapeVector &shape, void *data, TypeId data_type)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_type)) {}

  // Copy `size` elements of type U, converting each element to T.
  template <typename U>
  TensorDataImpl(const ShapeVector &shape, const U *input, size_t size)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData<T>(input, size)) {}

  // Create data holding a single scalar value cast to T.
  template <typename Scalar>
  TensorDataImpl(const ShapeVector &shape, Scalar scalar)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData<T>(scalar)) {}

  // Number of elements.
  ssize_t size() const override { return static_cast<ssize_t>(data_size_); }

  // Size of one element in bytes.
  ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); }

  // Total data size in bytes.
  ssize_t nbytes() const override { return size() * itemsize(); }

  // Number of dimensions.
  ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }

  bool is_sub_data() const override { return false; }

  bool has_sub_data() const override { return false; }

  // Return the element buffer, allocating it lazily on first access.
  // If an offload file path is set, the buffer is filled from that file;
  // read failures are logged and leave the buffer value-initialized.
  // NOTE(review): the lazy allocation is not synchronized — presumably callers
  // serialize first access externally; verify before concurrent use.
  void *data() override {
    if (data_ != nullptr) {
      return data_.get();
    }

    if (data_size_ > INT32_MAX) {
      MS_LOG(WARNING) << "Try to alloca a large memory, size is:" << data_size_ * sizeof(T);
    }
    // Lazy allocation.
    data_ = std::make_unique<T[]>(data_size_);

    // Load data from file
    if (!file_path_.empty()) {
      auto fs = mindspore::system::Env::GetFileSystem();
      MS_EXCEPTION_IF_NULL(fs);
      if (fs->FileExist(file_path_)) {
        // Opened via CreateWriteFile with "r+" to obtain a readable handle.
        auto file = fs->CreateWriteFile(file_path_, "r+");
        MS_EXCEPTION_IF_NULL(file);
        bool success = file->PRead(data_.get(), data_size_ * sizeof(T), 0);
        if (!success) {
          MS_LOG(WARNING) << "Tensor load data from file: " << file_path_ << " failed!";
        }
        if (!file->Close()) {
          MS_LOG(WARNING) << "Close tensor file: " << file_path_ << " failed!";
        }
      } else {
        MS_LOG(WARNING) << "Invalid tensor file path: " << file_path_;
      }
    }
    return data_.get();
  }

  // Set the offload file path used by data() for lazy loading.
  void set_file_path(const std::string &file_path) override { file_path_ = file_path; }

  const std::string file_path() const override { return file_path_; }

  // Raw buffer without triggering lazy allocation.
  const void *const_data() const override {
    // May return nullptr if data not initialized.
    return data_.get();
  }

  // Element-wise comparison with another TensorDataImpl of the same T.
  // Returns false when either side has no allocated buffer (even if both are
  // unallocated), unless the two objects are the same instance.
  virtual bool equals(const TensorDataImpl<T> &other) const {
    auto ptr = &other;
    if (ptr == this) {
      return true;
    }
    if (data_ == nullptr || ptr->data_ == nullptr) {
      return false;
    }
    return (ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
           std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get());
  }

  bool equals(const TensorData &other) const override {
    // Not same type, compare data byte by byte.
    return TensorData::equals(other);
  }

  // Render the data via TensorStringifier.
  std::string ToString(TypeId type, const ShapeVector &shape, bool use_comma) const override {
    TensorStringifier<T> stringifier{data_.get(), data_size_, ndim_};
    return stringifier.ToString(type, shape, use_comma);
  }

 private:
  // Delete and unregister the offload file, if one was set.
  void RemoveOffloadFile() {
    if (!file_path_.empty()) {
      TempFileManager::GetInstance().RemoveFile(file_path_);
      TempFileManager::GetInstance().UnRegister(file_path_);
      file_path_ = "";
    }
  }

  size_t ndim_{0};            // Number of dimensions.
  size_t data_size_{0};       // Number of elements.
  std::unique_ptr<T[]> data_; // Element buffer; nullptr until allocated.
  std::string file_path_{""}; // Optional offload file backing the buffer.
};
// Construct a TensorData instance whose element type matches `data_type`,
// forwarding `args` to the ImplClass<element_type> constructor.
// Returns nullptr (after logging an error) for unsupported data types.
template <template <class> class ImplClass = TensorDataImpl, typename... Args>
TensorDataPtr MakeTensorData(TypeId data_type, Args &&... args) {
  switch (data_type) {
    case kNumberTypeBool:
      return std::make_shared<ImplClass<bool>>(std::forward<Args>(args)...);
    case kNumberTypeUInt8:
      return std::make_shared<ImplClass<uint8_t>>(std::forward<Args>(args)...);
    case kNumberTypeInt4:
      // 4-bit integers are stored in int8 storage.
      return std::make_shared<ImplClass<int8_t>>(std::forward<Args>(args)...);
    case kNumberTypeInt8:
      return std::make_shared<ImplClass<int8_t>>(std::forward<Args>(args)...);
    case kNumberTypeInt16:
      return std::make_shared<ImplClass<int16_t>>(std::forward<Args>(args)...);
    case kNumberTypeInt:  // Generic int maps to int32 storage.
    case kNumberTypeInt32:
      return std::make_shared<ImplClass<int32_t>>(std::forward<Args>(args)...);
    case kNumberTypeInt64:
      return std::make_shared<ImplClass<int64_t>>(std::forward<Args>(args)...);
    case kNumberTypeUInt16:
      return std::make_shared<ImplClass<uint16_t>>(std::forward<Args>(args)...);
    case kNumberTypeUInt32:
      return std::make_shared<ImplClass<uint32_t>>(std::forward<Args>(args)...);
    case kNumberTypeUInt64:
      return std::make_shared<ImplClass<uint64_t>>(std::forward<Args>(args)...);
    case kNumberTypeFloat16:
      return std::make_shared<ImplClass<float16>>(std::forward<Args>(args)...);
    case kNumberTypeFloat:  // Generic float maps to float32 storage.
      return std::make_shared<ImplClass<float>>(std::forward<Args>(args)...);
    case kNumberTypeFloat32:
      return std::make_shared<ImplClass<float>>(std::forward<Args>(args)...);
    case kNumberTypeFloat64:
      return std::make_shared<ImplClass<double>>(std::forward<Args>(args)...);
#ifndef KERNEL_EXECUTOR_ANDROID
    // bfloat16 support is compiled out for the Android kernel executor build.
    case kNumberTypeBFloat16:
      return std::make_shared<ImplClass<bfloat16>>(std::forward<Args>(args)...);
#endif
    case kNumberTypeComplex64:
      return std::make_shared<ImplClass<ComplexStorage<float>>>(std::forward<Args>(args)...);
    case kNumberTypeComplex128:
      return std::make_shared<ImplClass<ComplexStorage<double>>>(std::forward<Args>(args)...);
    case kObjectTypeString:
      // String tensors are stored as raw bytes.
      return std::make_shared<ImplClass<uint8_t>>(std::forward<Args>(args)...);
    case kObjectTypeTensorType:
    case kObjectTypeMapTensorType:
      // Presumably an int-typed placeholder for nested tensor / map-tensor
      // objects — TODO confirm the intended element storage here.
      return std::make_shared<ImplClass<int>>(std::forward<Args>(args)...);
    default:
      break;
  }
  MS_LOG(ERROR) << "Cannot construct Tensor because of unsupported data type: " << TypeIdToString(data_type) << ".";
  return nullptr;
}
574 class BaseTensor;
575 using BaseTensorPtr = std::shared_ptr<BaseTensor>;
576 using BaseTensorPtrList = std::vector<std::shared_ptr<BaseTensor>>;
577 
578 // BaseTensor entity class
579 class MS_CORE_API BaseTensor : public MetaTensor {
580  public:
581   BaseTensor() = default;
582 
583   /// \brief Create base tensor from another base tensor, data is shared.
584   ///
585   /// \param[in] tensor [BaseTensor] The input base tensor.
586   explicit BaseTensor(const BaseTensor &tensor);
587 
588   /// \brief Create base tensor with given data type from another tensor.
589   ///
590   /// \param[in] tensor [BaseTensor] The input tensor.
591   /// \param[in] data_type [TypeId] The new tensor data type.
592   BaseTensor(const BaseTensor &tensor, TypeId data_type);
593 
594   /// \brief Create base tensor with the given shared tensor data.
595   ///
596   /// \param[in] data_type [TypeId] Data type of the tensor.
597   /// \param[in] shape The shape represented by ShapeVector of the tensor.
598   /// \param[in] data The shared tensor data.
599   BaseTensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data);
600 
601   /// \brief Create a lazy allocated tensor.
602   ///
603   /// \param[in] data_type [TypeId] Data type of the tensor.
604   /// \param[in] shape The shape represented by ShapeVector of the tensor.
605   BaseTensor(TypeId data_type, const ShapeVector &shape);
606 
607   /// \brief Create a tensor with input data buffer.
608   ///
609   /// \param[in] data_type [TypeId] Data type of the tensor.
610   /// \param[in] shape The shape represented by ShapeVector of the tensor.
611   /// \param[in] data The input data to be copied into tensor.
612   /// \param[in] data_len The length of data in bytes.
613   BaseTensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len);
614 
615   /// \brief Create a tensor with input data buffer and given source data type.
616   ///
617   /// \param[in] data_type [TypeId] Data type of the tensor.
618   /// \param[in] shape The shape represented by ShapeVector of the tensor.
619   /// \param[in] data The input data to be copied into tensor.
620   /// \param[in] src_data_type The source data type.
621   BaseTensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type);
622 
623   /// \brief Create 1 dimension tensor from an int vector.
624   ///
625   /// \param[in] input [std::vector<int64_t>] the data for tensor.
626   /// \param[in] data_type [TypeId] data type.
627   explicit BaseTensor(const std::vector<int64_t> &input, const TypePtr &data_type = nullptr);
628 
629   /// \brief Create 1 dimension tensor from an int vector.
630   ///
631   /// \param[in] input [std::vector<int32_t>] the data for tensor.
632   /// \param[in] data_type [TypeId] data type.
633   explicit BaseTensor(const std::vector<int32_t> &input, const TypePtr &data_type = nullptr);
634 
635   /// \brief Create 1 dimension tensor from a float vector.
636   ///
637   /// \param[in] input [std::vector<double>] the data for tensor.
638   /// \param[in] data_type [TypeId] data type.
639   explicit BaseTensor(const std::vector<double> &input, const TypePtr &data_type = nullptr);
640 
641   /// \brief Create 1 dimension tensor from a float vector.
642   ///
643   /// \param[in] input [std::vector<float>] the data for tensor.
644   /// \param[in] data_type [TypeId] data type.
645   explicit BaseTensor(const std::vector<float> &input, const TypePtr &data_type = nullptr);
646 
647   /// \brief Create 0 dimension tensor from an int64_t scalar.
648   ///
649   /// \param[in] input [int64] the data for tensor.
650   /// \param[in] data_type [TypeId] data type.
651   explicit BaseTensor(int64_t input, const TypePtr &data_type = nullptr);
652 
653   /// \brief Create 0 dimension tensor from an int32_t scalar.
654   ///
655   /// \param[in] input [int32] the data for tensor.
656   /// \param[in] data_type [TypeId] data type.
657   explicit BaseTensor(int32_t input, const TypePtr &data_type = nullptr);
658 
659   /// \brief Create 0 dimension tensor from an int16_t scalar.
660   ///
661   /// \param[in] input [int16] the data for tensor.
662   /// \param[in] data_type [TypeId] data type.
663   explicit BaseTensor(int16_t input, const TypePtr &data_type = nullptr);
664 
665   /// \brief Create 0 dimension tensor from an int8_t scalar.
666   ///
667   /// \param[in] input [int8] the data for tensor.
668   /// \param[in] data_type [TypeId] data type.
669   explicit BaseTensor(int8_t input, const TypePtr &data_type = nullptr);
670 
671   /// \brief Create 0 dimension tensor from a double scalar.
672   ///
673   /// \param[in] input [double] the data for tensor.
674   /// \param[in] data_type [TypeId] data type.
675   explicit BaseTensor(double input, const TypePtr &data_type = nullptr);
676 
677   /// \brief Create 0 dimension tensor from a float scalar.
678   ///
679   /// \param[in] input [float] the data for tensor.
680   /// \param[in] data_type [TypeId] data type.
681   explicit BaseTensor(float input, const TypePtr &data_type = nullptr);
682 
683   /// \brief Create 0 dimension tensor from a float16 scalar.
684   ///
685   /// \param[in] input [float16] the data for tensor.
686   /// \param[in] data_type [TypeId] data type.
687   explicit BaseTensor(float16 input, const TypePtr &data_type = nullptr);
688 
689   /// \brief Create 0 dimension tensor from a bfloat16 scalar.
690   ///
  /// \param[in] input [bfloat16] the data for tensor.
  /// \param[in] data_type [TypeId] data type.
  explicit BaseTensor(bfloat16 input, const TypePtr &data_type = nullptr);

  /// \brief Create 0 dimension tensor from a uint64 scalar.
  ///
  /// \param[in] input [uint64] the data for tensor.
  /// \param[in] data_type [TypeId] data type.
  explicit BaseTensor(uint64_t input, const TypePtr &data_type = nullptr);

  /// \brief Create 0 dimension tensor from a uint32 scalar.
  ///
  /// \param[in] input [uint32] the data for tensor.
  /// \param[in] data_type [TypeId] data type.
  explicit BaseTensor(uint32_t input, const TypePtr &data_type = nullptr);

  /// \brief Create 0 dimension tensor from a uint16 scalar.
  ///
  /// \param[in] input [uint16] the data for tensor.
  /// \param[in] data_type [TypeId] data type.
  explicit BaseTensor(uint16_t input, const TypePtr &data_type = nullptr);

  /// \brief Create 0 dimension tensor from a uint8 scalar.
  ///
  /// \param[in] input [uint8] the data for tensor.
  /// \param[in] data_type [TypeId] data type.
  explicit BaseTensor(uint8_t input, const TypePtr &data_type = nullptr);

  /// \brief Create 0 dimension tensor from a bool scalar.
  ///
  /// \param[in] input [bool] the data for tensor.
  /// \param[in] data_type [TypeId] data type.
  explicit BaseTensor(bool input, const TypePtr &data_type = nullptr);

  /// \brief Create a chunk tensor with the given data size.
  ///
  /// \param[in] data_type [TypeId] Data type of the tensor.
  /// \param[in] data_size The tensor chunk data size in number of elements.
  BaseTensor(TypeId data_type, size_t data_size);

  /// \brief Create a Tensor which shape and size may be inconsistent, such as Tensor with compression data.
  ///
  /// \param[in] origin_data_type [TypeId] Data type of the origin tensor.
  /// \param[in] shape The shape represented by ShapeVector of the tensor.
  /// \param[in] compression_data_size The compression data buffer size.
  /// \param[in] compression_type The tensor compression type.
  BaseTensor(TypeId origin_data_type, const ShapeVector &shape, size_t compression_data_size,
             TensorCompressionType compression_type);

  /// \brief Copy assignment from another BaseTensor.
  ///
  /// \param[in] tensor The tensor to assign from.
  /// \return Reference to this tensor.
  BaseTensor &operator=(const BaseTensor &tensor);

  /// Destructor of BaseTensor.
  ~BaseTensor() override = default;

  MS_DECLARE_PARENT(BaseTensor, MetaTensor);
746 
  /// \brief Assign value to this tensor.
  ///
  /// \param[in] tensor The input tensor.
  /// \return Tensor with new value.
  BaseTensor &AssignValue(const BaseTensor &tensor);

  /// \brief Compare two tensor objects to see if they have same data type, shape and data address.
  ///
  /// \param[in] tensor The BaseTensor object to be compared.
  /// \return True if having same type, shape and data address, otherwise false.
  bool operator==(const BaseTensor &tensor) const;

  /// \brief Create Abstract for Tensor.
  ///
  /// \return Abstract of Tensor.
  abstract::AbstractBasePtr ToAbstract() override;

  /// \brief Get Abstract cache. The value of the abstract is null.
  /// Only used by InferShape in PyNative mode.
  ///
  /// \return Abstract of tensor.
  abstract::AbstractBasePtr GetAbstractCache();

  /// \brief It is different from 'operator==' which just compares shape/type/address,
  /// it does real value comparison.
  ///
  /// \param[in] tensor The BaseTensor object to be compared.
  /// \return True if it has the same value, otherwise false.
  bool ValueEqual(const BaseTensor &tensor) const;
776 
777   bool operator==(const Value &other) const override {
778     if (other.isa<BaseTensor>()) {
779       auto &other_ = static_cast<const BaseTensor &>(other);
780       return *this == other_;
781     }
782     return false;
783   }
784 
785   /// \brief Gets tensor's dimension.
786   ///
787   /// \return The number of dimensions of the tensor data.
DataDim()788   int DataDim() const { return static_cast<int>(data().ndim()); }
789 
790   /// \brief Getting tensor data size.
791   ///
792   /// \return The total number of elements of the tensor data.
DataSize()793   size_t DataSize() const { return data().size(); }
794 
795   /// \brief Get the data type of the tensor for C++
796   ///
797   /// \return [int] The tensor's data type will be cast to int to return.
data_type_c()798   int data_type_c() const { return static_cast<int>(data_type_); }
799 
800   /// \brief Get the tensor's shape for C++
801   ///
802   /// \return [ShapeVector]
shape_c(void)803   ShapeVector shape_c(void) const { return shape(); }
804 
805   /// \brief Get BaseTensor data pointer for c++ type
806   ///
807   /// \return The pointer to the object
data_c()808   void *data_c() { return data().data(); }
809 
810   /// \brief Get BaseTensor data byte-size for c++ type
811   ///
812   /// \return byte size of BaseTensor data
Size()813   size_t Size() const { return static_cast<size_t>(data().nbytes()); }
814 
815   /// \brief The pointer to the object
data_c()816   void *data_c() const { return data_->data(); }
817 
818   /// \brief To synchronize data with the device, you need to wait for the data to be valid.
819   ///
820   void data_sync(bool need_wait = true) const;
821 
822   /// \brief Get the internal data object.
823   ///
824   /// \return The reference to internal data object.
data()825   TensorData &data() {
826     MS_EXCEPTION_IF_NULL(data_);
827     return *data_;
828   }
829 
830   /// \brief Get the internal data shared pointer.
831   ///
832   /// return The reference to internal data object.
data_ptr()833   const TensorDataPtr &data_ptr() const { return data_; }
834 
835   /// \brief Get the internal data object.
836   ///
837   /// \return The reference to internal data object.
data()838   const TensorData &data() const { return *data_; }
839 
set_data(const TensorDataPtr & data)840   void set_data(const TensorDataPtr &data) { data_ = data; }
841 
  /// \brief Set the data type of this tensor (overrides MetaTensor).
  ///
  /// \param[in] data_type The new data type.
  /// \return The previous data type.
  TypeId set_data_type(TypeId data_type) override;

  /// \brief Set the shape of this tensor (overrides MetaTensor).
  ///
  /// \param[in] shape The new shape.
  /// \return The element count implied by the new shape.
  size_t set_shape(const ShapeVector &shape) override;

  /// \brief Get information about shape and data type.
  ///
  /// \return Information about shape and data type.
  std::string GetShapeAndDataTypeInfo() const;

  /// \brief Get display information of limit size.
  ///
  /// \param[in] limit_size The limit size.
  /// \return The display information of limit size.
  std::string ToStringInternal(size_t limit_size) const;

  /// \brief Get display information with unlimited size.
  ///
  /// \return The display information with unlimited size.
  std::string ToStringNoLimit() const;

  /// \brief Get display information of this BaseTensor.
  ///
  /// \return The display information of this BaseTensor.
  std::string ToString() const override;

  /// \brief Get display information in repr form.
  ///
  /// \return The display information in repr form.
  std::string ToStringRepr() const;
871 
872   /// \brief Check if this BaseTensor is forward output.
873   ///
874   /// \return Whether this BaseTensor is forward output.
is_forward_output()875   bool is_forward_output() const { return is_forward_output_; }
876 
877   /// \brief Set the forward output flag of this BaseTensor.
878   ///
879   /// \param[in] is_forward_output Whether this BaseTensor is forward output.
set_is_forward_output(bool is_forward_output)880   void set_is_forward_output(bool is_forward_output) { is_forward_output_ = is_forward_output; }
881 
882   /// \brief Get the device address.
883   ///
884   /// \return The device address.
885   DeviceSyncPtr device_address() const;
886 
887   /// \brief Set the device address.
888   ///
889   /// \param[in] device_sync The input Device synchronization.
890   /// \param[in] need_update_ref_count If need_update_ref_count is true, the device address cannot be released and
891   /// reused, so the feature map should set false when set device address of tensor.
892   void set_device_address(const DeviceSyncPtr &device_sync, bool need_update_ref_count = true);
893 
894   /// \brief Get the id of this BaseTensor.
895   ///
896   /// \return The id of this BaseTensor.
id()897   std::string id() const { return id_; }
898 
899   /// \brief Set lazy callback function to this Tensor
900   ///
901   /// \param[in] lazy_callback Wait for async tasks finish before data_sync.
RegisterLazyCallback(const std::function<void (void)> & lazy_callback)902   static void RegisterLazyCallback(const std::function<void(void)> &lazy_callback) { lazy_callback_ = lazy_callback; }
903 
904   /// \brief Set contiguous callback function to this BaseTensor
905   ///
906   /// \param[in] contiguous_callback The callback from backend when need to make tensor contiguous.
set_contiguous_callback(const std::function<DeviceSyncPtr (const DeviceSyncPtr &)> & contiguous_callback)907   void set_contiguous_callback(const std::function<DeviceSyncPtr(const DeviceSyncPtr &)> &contiguous_callback) {
908     contiguous_callback_ = contiguous_callback;
909   }
910 
911   /// @brief Get Pynative auto_grad meta data.
912   /// @return Auto grad meta data
auto_grad_meta_data()913   const AutoGradMetaDataPtr &auto_grad_meta_data() const { return auto_grad_meta_data_; }
914 
915   /// @brief Set Pynative auto_grad meta data.
916   /// @param auto_grad_meta_data
set_auto_grad_meta_data(const AutoGradMetaDataPtr & auto_grad_meta_data)917   void set_auto_grad_meta_data(const AutoGradMetaDataPtr &auto_grad_meta_data) {
918     auto_grad_meta_data_ = auto_grad_meta_data;
919   }
920 
921   /// \brief Get tensor storage info.
922   ///
923   /// \return BaseTensor storage info, the value is nullptr default.
924   const TensorStorageInfoPtr storage_info() const;
925 
926   /// \brief Set tensor abstract.
927   ///
928   /// \param[in] abstract The abstract of tensor.
set_abstract(const std::weak_ptr<abstract::AbstractBase> & abstract)929   void set_abstract(const std::weak_ptr<abstract::AbstractBase> &abstract) { abstract_ = abstract; }
930 
931   /// \brief Set synchronization status.
932   ///
933   /// \param[in] sync_status The input synchronization status.
set_sync_status(TensorSyncStatus sync_status)934   void set_sync_status(TensorSyncStatus sync_status) const { sync_status_ = sync_status; }
935 
936   /// \brief Get synchronization status.
937   ///
938   /// \return The synchronization status.
sync_status()939   TensorSyncStatus sync_status() const { return sync_status_; }
940 
941   /// \brief Check the value of sync_status_.
942   ///
943   /// \return Ture if sync_status_ is kNeedSyncDeviceToHostImmediately.
NeedSyncDeviceToHostImmediately()944   bool NeedSyncDeviceToHostImmediately() const { return sync_status_ == kNeedSyncDeviceToHostImmediately; }
945 
946   /// \brief Check the value of sync_status_.
947   ///
948   /// \return Ture if sync_status_ is kNeedSyncDeviceToHost.
NeedSyncDeviceToHost()949   bool NeedSyncDeviceToHost() const { return sync_status_ == kNeedSyncDeviceToHost; }
950 
951   /// \brief Check the value of sync_status_.
952   ///
953   /// \return Ture if sync_status_ is kNeedSyncHostToDevice.
NeedSyncHostToDevice()954   bool NeedSyncHostToDevice() const { return sync_status_ == kNeedSyncHostToDevice; }
955 
956   /// \brief Check the value of sync_status_.
957   ///
958   /// \return Ture if sync_status_ is kNeedSyncHostToDeviceImmediately.
NeedSyncHostToDeviceImmediately()959   bool NeedSyncHostToDeviceImmediately() const { return sync_status_ == kNeedSyncHostToDeviceImmediately; }
960 
961   /// \brief Get tensor's BaseShape.
962   ///
963   /// \return The BaseShape of this tensor.
base_shape_ptr()964   const BaseShapePtr &base_shape_ptr() const { return base_shape_ptr_; }
965 
966   /// \brief Set tensor's BaseShape.
967   ///
968   /// \param[in] BaseShapePtr The tensor's BaseShape.
set_base_shape(const BaseShapePtr & base_shape)969   void set_base_shape(const BaseShapePtr &base_shape) { base_shape_ptr_ = base_shape; }
970 
971   /// \brief Determines whether the memory of tensor is contiguous.
972   ///
973   /// \return True if tensor memory is contiguous, false otherwise.
974   bool is_contiguous() const;
975 
976   /// \brief Get tensor storage stride.
977   ///
978   /// \return storage stride.
979   std::vector<int64_t> stride() const;
980 
981   /// \brief Get tensor storage offset.
982   ///
983   /// \return storage offset.
984   const int64_t storage_offset() const;
985 
set_need_pipeline_sync(bool need_pipeline_sync)986   void set_need_pipeline_sync(bool need_pipeline_sync) { need_pipeline_sync_ = need_pipeline_sync; }
987 
988   /// \brief Execute lazy task.
989   ///
990   void ExecuteLazyTask() const;
991 
 protected:
  // True when this tensor is the output of a forward computation (see set_is_forward_output).
  bool is_forward_output_{false};
  // True when a pipeline sync is required before the tensor data is accessed.
  bool need_pipeline_sync_{false};
  // Identifier of this tensor (returned by id()).
  std::string id_{""};
  // Device address; mutable so const accessors (e.g. device_address()/data_sync()) can update it.
  mutable DeviceSyncPtr device_sync_{nullptr};
  // Host/device sync state; mutable because set_sync_status() is a const method.
  mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
  // PyNative auto-grad meta data (may be null).
  AutoGradMetaDataPtr auto_grad_meta_data_{nullptr};
  // Non-owning reference to the cached abstract (see set_abstract/GetAbstractCache).
  std::weak_ptr<abstract::AbstractBase> abstract_;
  // Host-side tensor data buffer (may be null until allocated).
  TensorDataPtr data_{nullptr};
  // Tensor base shape which contain dynamic shape info.
  BaseShapePtr base_shape_ptr_{nullptr};
  // Process-wide callback that waits for async tasks before data_sync (see RegisterLazyCallback).
  inline static std::function<void(void)> lazy_callback_{nullptr};
  // Backend callback used to make a tensor contiguous when needed.
  std::function<DeviceSyncPtr(const DeviceSyncPtr &)> contiguous_callback_{nullptr};
};
1006 
// Convert a shape vector to its human-readable string form, e.g. for logging.
MS_CORE_API std::string ShapeToString(const ShapeVector &shape);
1009 
CopyTensorData(const TensorDataPtr & dest,const TensorDataPtr & src)1010 inline static void CopyTensorData(const TensorDataPtr &dest, const TensorDataPtr &src) {
1011   auto dest_bytes = dest->nbytes();
1012   auto src_bytes = src->nbytes();
1013   auto err = common::huge_memcpy(static_cast<uint8_t *>(dest->data()), dest_bytes,
1014                                  static_cast<const uint8_t *>(src->const_data()), src_bytes);
1015   if (err != EOK) {
1016     MS_LOG(INTERNAL_EXCEPTION) << "Copy tensor data failed! bytes: " << dest_bytes << "/" << src_bytes << ".";
1017   }
1018 }
1019 }  // namespace tensor
1020 }  // namespace mindspore
1021 
1022 #endif  // MINDSPORE_CORE_IR_BASE_TENSOR_H_
1023