/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ir/tensor.h"

#include <atomic>
#include <functional>
#include <numeric>
#include <vector>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <iomanip>
#include <algorithm>
#include <type_traits>
#include <typeinfo>

#include "abstract/utils.h"
#include "abstract/abstract_value.h"
#include "base/complex_storage.h"

namespace mindspore {
namespace tensor {
constexpr auto kEllipsis = "...";
constexpr auto kThreshold = 6;

constexpr auto kThreshold1DFloat = kThreshold * 2;
constexpr auto kThreshold1DInt = kThreshold * 4;
constexpr auto kThreshold1DBool = kThreshold * 2;
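// Note: for tensors with more than one dimension, a bottom dimension longer than
// kThreshold is summarized around an ellipsis, e.g. "[1 2 3 ... 8 9 10]".
// 1-D tensors are printed in full, with a line break every kThreshold1D* elements.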

static std::string MakeId() {
  // Use an atomic counter to make the id generator thread-safe.
  static std::atomic<uint64_t> last_id{1};
  return "T" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
}

static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
  return data_type ? data_type->type_id() : defaultTypeId;
}

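// Returns the number of elements implied by the shape (the product of its dimensions).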
static size_t SizeOf(const ShapeVector &shape) {
  return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
}

static std::string ShapeToString(const ShapeVector &shape) {
  std::string str = "[";
  const size_t count = shape.size();
  for (size_t i = 0; i < count; ++i) {
    if (i > 0) {
      str.append(", ");
    }
    str.append(std::to_string(shape[i]));
  }
  return str.append("]");
}

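// Copies `size` elements from `input` into a newly allocated T[] buffer,
// converting the element type from U to T when they differ.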
template <typename T, typename U>
std::unique_ptr<T[]> NewData(const U *input, size_t size) {
  if (input == nullptr || size == 0) {
    return nullptr;
  }
  auto data = std::make_unique<T[]>(size);
  if constexpr (!std::is_same<T, U>::value &&
                (std::is_same<T, float16>::value || std::is_same<U, float16>::value ||
                 std::is_same<T, ComplexStorage<float>>::value || std::is_same<U, ComplexStorage<float>>::value ||
                 std::is_same<T, ComplexStorage<double>>::value || std::is_same<U, ComplexStorage<double>>::value)) {
    // Because float16 and ComplexStorage do not support implicit conversion
    // from/to other types, we cannot use std::copy() here; cast element by element instead.
    for (size_t i = 0; i < size; ++i) {
      data[i] = static_cast<T>(input[i]);
    }
  } else {
    // Otherwise, use std::copy() for better performance.
    std::copy(input, input + size, data.get());
  }
  return data;
}

template <typename T, typename Scalar>
std::unique_ptr<T[]> NewData(Scalar scalar) {
  auto data = std::make_unique<T[]>(1);
  data[0] = static_cast<T>(scalar);
  return data;
}

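// Copies tensor data described by `shape` from a raw buffer whose element type
// is given by `data_type`, converting to T as needed.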
template <typename T>
std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId data_type) {
  const size_t size = SizeOf(shape);
  switch (data_type) {
    case kNumberTypeBool: {
      auto buf = static_cast<bool *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt8: {
      auto buf = static_cast<uint8_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt8: {
      auto buf = static_cast<int8_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt16: {
      auto buf = static_cast<int16_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt32: {
      auto buf = static_cast<int32_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt64: {
      auto buf = static_cast<int64_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt16: {
      auto buf = static_cast<uint16_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt32: {
      auto buf = static_cast<uint32_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt64: {
      auto buf = static_cast<uint64_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat16: {
      auto buf = static_cast<float16 *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat32: {
      auto buf = static_cast<float *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat64: {
      auto buf = static_cast<double *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeComplex64: {
      auto buf = static_cast<ComplexStorage<float> *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeComplex128: {
      auto buf = static_cast<ComplexStorage<double> *>(data);
      return NewData<T>(buf, size);
    }
    case kObjectTypeString: {
      auto buf = static_cast<uint8_t *>(data);
      return NewData<T>(buf, size);
    }
    default:
      break;
  }
  MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
}

template <typename T>
std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, size_t data_len) {
  size_t size = SizeOf(shape);
  if (size * sizeof(T) != data_len) {
    MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
                      << ", item size " << sizeof(T) << ".";
  }
  auto buf = static_cast<T *>(data);
  return NewData<T>(buf, size);
}

// Tensor data implementation.
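// Stores the elements of a tensor as a flat T[] buffer; allocation is deferred
// until data() is first called, so a freshly constructed tensor holds no memory.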
template <typename T>
class TensorDataImpl : public TensorData {
 public:
  explicit TensorDataImpl(const ShapeVector &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
  ~TensorDataImpl() = default;

  TensorDataImpl(const ShapeVector &shape, void *data, size_t data_len)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)) {}

  TensorDataImpl(const ShapeVector &shape, void *data, TypeId data_type)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_type)) {}

  template <typename U>
  TensorDataImpl(const ShapeVector &shape, const U *input, size_t size)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData<T>(input, size)) {}

  template <typename Scalar>
  TensorDataImpl(const ShapeVector &shape, Scalar scalar)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData<T>(scalar)) {}

  ssize_t size() const override { return static_cast<ssize_t>(data_size_); }

  ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); }

  ssize_t nbytes() const override { return size() * itemsize(); }

  ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }

  void *data() override {
    if (data_ == nullptr) {
      // Lazy allocation.
      data_ = std::make_unique<T[]>(data_size_);
    }
    return data_.get();
  }

  const void *const_data() const override {
    // May return nullptr if data is not initialized.
    return data_.get();
  }

  bool equals(const TensorData &other) const override {
    auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other);
    if (ptr == nullptr) {
      // Not the same type; compare data byte by byte.
      return TensorData::equals(other);
    }
    if (ptr == this) {
      return true;
    }
    if (data_ == nullptr || ptr->data_ == nullptr) {
      return false;
    }
    return (ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
           std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get());
  }

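  // Renders the tensor elements as a (possibly summarized) string. Large tensors
  // are abbreviated with an ellipsis; integer output uses '#' placeholders that
  // ProcessPlaceholder() later replaces with right-aligned padding.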
  std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override {
    constexpr auto valid =
      std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
      std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
      std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value ||
      std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value ||
      std::is_same<T, ComplexStorage<float>>::value || std::is_same<T, ComplexStorage<double>>::value;
    static_assert(valid, "Type is invalid");
    if (data_size_ == 0) {
      return "";
    }
    if (data_ == nullptr) {
      return "<uninitialized>";
    }

    std::ostringstream ss;
    if (data_size_ == 1 && ndim_ == 0) {  // Scalar
      OutputDataString(ss, 0, 0, 1, false, nullptr);
      return ss.str();
    }

    int num_width = 0;
    ssize_t cursor = 0;
    SummaryStringRecursive(ss, shape, &cursor, 0, use_comma, &num_width);
    return ProcessPlaceholder(ss, num_width);
  }

 private:
  void OutputFloatDataString(std::ostringstream &ss, bool isScalar, const T &value) const {
    if (isScalar) {
      ss << value;
    } else {
      // The field width of float16 is fixed at 11, while float/double is fixed at 15.
      const int width = std::is_same<T, float16>::value ? 11 : 15;
      // The printing precision of float16 is fixed at 4, while float/double is fixed at 8.
      const int precision = std::is_same<T, float16>::value ? 4 : 8;
      ss << std::setw(width) << std::setprecision(precision) << std::setiosflags(std::ios::scientific | std::ios::right)
         << value;
    }
  }

  void OutputBoolDataString(std::ostringstream &ss, bool isScalar, const T &value) const {
    if (isScalar) {
      ss << (value ? "True" : "False");
    } else {
      constexpr int bool_max_width = sizeof("False") - 1;
      ss << std::setw(bool_max_width) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
    }
  }

  void OutputOtherDataString(std::ostringstream &ss, bool isScalar, const T &value, int *max_width) const {
    if (isScalar) {
      ss << value;
    } else {
      // Add a padding string of '#' before the number, such as "###123";
      // ProcessPlaceholder() later replaces it so that columns are right-aligned.
      const int width = GetNumLength(value);
      *max_width = std::max(*max_width, width);
      std::string pad(width, '#');
      ss << pad;
      if constexpr (std::is_same<T, uint8_t>::value) {
        // Widen 8-bit integers so they print as numbers rather than characters.
        ss << static_cast<uint16_t>(value);
      } else if constexpr (std::is_same<T, int8_t>::value) {
        ss << static_cast<int16_t>(value);
      } else {
        ss << value;
      }
    }
  }

  void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end, bool use_comma,
                        int *max_width) const {
    const bool isScalar = ndim_ == 0 && end - start == 1;
    constexpr auto isBool = std::is_same<T, bool>::value;
    constexpr auto isFloat =
      std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
    constexpr auto isComplex =
      std::is_same<T, ComplexStorage<float>>::value || std::is_same<T, ComplexStorage<double>>::value;
    constexpr int linefeedThreshold = isFloat ? kThreshold1DFloat : (isBool ? kThreshold1DBool : kThreshold1DInt);
    for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
      const auto value = data_[cursor + i];
      if constexpr (isComplex) {
        ss << value;
      } else if constexpr (isFloat) {
        OutputFloatDataString(ss, isScalar, value);
      } else if constexpr (isBool) {
        OutputBoolDataString(ss, isScalar, value);
      } else {
        OutputOtherDataString(ss, isScalar, value, max_width);
      }
      if (!isScalar && i != end - 1) {
        if (use_comma) {
          ss << ',';
        }
        ss << ' ';
      }
      if (!isScalar && ndim_ == 1 && (i + 1) % linefeedThreshold == 0) {
        // Add a line feed every per-type threshold of elements for 1-D tensors.
        ss << '\n' << ' ';
      }
    }
  }

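  // Recursively renders one dimension per nesting level, summarizing any
  // dimension longer than kThreshold by printing its first and last
  // kThreshold/2 entries around an ellipsis. *cursor tracks the flat index
  // into data_, including elements skipped by the summary.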
  void SummaryStringRecursive(std::ostringstream &ss, const ShapeVector &shape, ssize_t *cursor, ssize_t depth,
                              bool use_comma, int *max_width) const {
    if (depth >= static_cast<ssize_t>(ndim_)) {
      return;
    }
    ss << '[';
    if (depth == static_cast<ssize_t>(ndim_) - 1) {  // Bottom dimension
      ssize_t num = shape[depth];
      if (num > kThreshold && ndim_ > 1) {
        OutputDataString(ss, *cursor, 0, kThreshold >> 1, use_comma, max_width);
        ss << ' ' << kEllipsis << ' ';
        OutputDataString(ss, *cursor, num - (kThreshold >> 1), num, use_comma, max_width);
      } else {
        OutputDataString(ss, *cursor, 0, num, use_comma, max_width);
      }
      *cursor += num;
    } else {  // Middle dimension
      ssize_t num = shape[depth];
      // Handle the first half.
      for (ssize_t i = 0; i < std::min(static_cast<ssize_t>(kThreshold >> 1), num); i++) {
        if (i > 0) {
          if (use_comma) {
            ss << ',';
          }
          ss << '\n';
          ss << std::setw(depth + 1) << ' ';  // Add the indent.
        }
        SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma, max_width);
      }
      // Handle the ignored part.
      if (num > kThreshold) {
        if (use_comma) {
          ss << ',';
        }
        ss << '\n';
        ss << std::setw(depth + 1) << ' ';  // Add the indent.
        ss << kEllipsis;
        // Count the elements skipped at this layer: the size of one subarray...
        ssize_t ignored = shape[depth + 1];
        constexpr ssize_t offset = 2;
        for (ssize_t i = depth + offset; i < static_cast<ssize_t>(ndim_); i++) {
          ignored *= shape[i];
        }
        // ...multiplied by the number of skipped subarrays.
        ignored *= num - kThreshold;
        *cursor += ignored;
      }
      // Handle the second half.
      if (num > (kThreshold >> 1)) {
        ssize_t iter_times =
          std::min(static_cast<ssize_t>(num - (kThreshold >> 1)), static_cast<ssize_t>(kThreshold >> 1));
        for (ssize_t i = 0; i < iter_times; i++) {
          if (use_comma && (i != 0 || num <= kThreshold)) {  // Skip the comma right after the ellipsis line.
            ss << ',';
          }
          ss << '\n';
          ss << std::setw(depth + 1) << ' ';  // Add the indent.
          SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma, max_width);
        }
      }
    }
    ss << ']';
  }

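  // Replaces every run of '#' written by OutputOtherDataString() with spaces so
  // that each number is right-aligned in a field of max_width characters; e.g.
  // "#7 ##42" becomes " 7 42" when the widest number has two digits.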
  std::string ProcessPlaceholder(std::ostringstream &ss, int max_width) const {
    std::string str = ss.str();
    if constexpr (std::is_same<T, bool>::value || std::is_same<T, float16>::value || std::is_same<T, float>::value ||
                  std::is_same<T, double>::value) {
      return str;
    }
    // Replace each '#' placeholder run with the padding spaces it stands for.
    size_t index = str.find('#');
    while (index != std::string::npos) {
      size_t pos = index;
      while (str[pos] == '#') {
        pos++;
      }
      size_t len = pos - index;
      std::string space(max_width - SizeToInt(len), ' ');
      str = str.replace(index, len, space);
      index = str.find('#', index);
    }
    return str;
  }

  int GetNumLength(const T &num) const {
    T value = num;
    int count = 0;
    if (value <= 0) {  // Count the '-' sign of a negative, or the single digit of zero.
      count++;
    }
    while (value != 0) {
      value /= 10;
      count++;
    }
    return count;
  }

  size_t ndim_{0};
  size_t data_size_{0};
  std::unique_ptr<T[]> data_;
};

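// Factory that instantiates the TensorDataImpl<T> matching `data_type`,
// forwarding any extra arguments to its constructor. A minimal usage sketch
// (illustrative only):
//   ShapeVector shape{2, 3};
//   TensorDataPtr data = MakeTensorData(kNumberTypeFloat32, shape);  // 6 floats, lazily allocated
// Note that string tensors are backed by raw uint8_t storage.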
template <typename... Args>
TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const Args... args) {
  switch (data_type) {
    case kNumberTypeBool:
      return std::make_shared<TensorDataImpl<bool>>(shape, args...);
    case kNumberTypeUInt8:
      return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
    case kNumberTypeInt8:
      return std::make_shared<TensorDataImpl<int8_t>>(shape, args...);
    case kNumberTypeInt16:
      return std::make_shared<TensorDataImpl<int16_t>>(shape, args...);
    case kNumberTypeInt32:
      return std::make_shared<TensorDataImpl<int32_t>>(shape, args...);
    case kNumberTypeInt64:
      return std::make_shared<TensorDataImpl<int64_t>>(shape, args...);
    case kNumberTypeUInt16:
      return std::make_shared<TensorDataImpl<uint16_t>>(shape, args...);
    case kNumberTypeUInt32:
      return std::make_shared<TensorDataImpl<uint32_t>>(shape, args...);
    case kNumberTypeUInt64:
      return std::make_shared<TensorDataImpl<uint64_t>>(shape, args...);
    case kNumberTypeFloat16:
      return std::make_shared<TensorDataImpl<float16>>(shape, args...);
    case kNumberTypeFloat:
      return std::make_shared<TensorDataImpl<float>>(shape, args...);
    case kNumberTypeFloat32:
      return std::make_shared<TensorDataImpl<float>>(shape, args...);
    case kNumberTypeFloat64:
      return std::make_shared<TensorDataImpl<double>>(shape, args...);
    case kNumberTypeComplex64:
      return std::make_shared<TensorDataImpl<ComplexStorage<float>>>(shape, args...);
    case kNumberTypeComplex128:
      return std::make_shared<TensorDataImpl<ComplexStorage<double>>>(shape, args...);
    case kObjectTypeString:
      return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
    case kObjectTypeTensorType:
      return std::make_shared<TensorDataImpl<int>>(shape, args...);
    default:
      break;
  }
  MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
}

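// The copy constructor shares the underlying TensorData with the source tensor
// (a shallow copy); use Tensor(tensor, data_type) to convert and copy the data.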
Tensor::Tensor(const Tensor &tensor)
    : MetaTensor(tensor),
      init_flag_(tensor.init_flag_),
      data_(tensor.data_),
      id_(tensor.id_),
      event_(tensor.event_),
      need_wait_(tensor.need_wait_),
      sync_status_(tensor.sync_status_),
      device_sync_(tensor.device_sync_),
      need_release_device_mem_(tensor.need_release_device_mem_),
      cache_enable_(tensor.cache_enable_),
      cache_tensor_ptr_(tensor.cache_tensor_ptr_),
      hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
      padding_type_(tensor.padding_type()),
      device_event_(tensor.device_event_) {}

Tensor::Tensor(const Tensor &tensor, TypeId data_type)
    : MetaTensor(data_type, tensor.shape_),
      init_flag_(tensor.init_flag_),
      data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
      id_(tensor.id_),
      event_(tensor.event_),
      need_wait_(tensor.need_wait_),
      sync_status_(tensor.sync_status_),
      device_sync_(tensor.device_sync_),
      need_release_device_mem_(tensor.need_release_device_mem_),
      cache_enable_(tensor.cache_enable_),
      cache_tensor_ptr_(tensor.cache_tensor_ptr_),
      hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
      padding_type_(tensor.padding_type()),
      device_event_(tensor.device_event_) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
    : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape)
    : Tensor(data_type, shape, MakeTensorData(data_type, shape)) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len)
    : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type)
    : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {}

Tensor::Tensor(const std::vector<int64_t> &input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast<int>(input.size())}),
      data_(MakeTensorData(data_type_, shape_, input.data(), input.size())),
      id_(MakeId()) {}

Tensor::Tensor(const std::vector<double> &input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast<int>(input.size())}),
      data_(MakeTensorData(data_type_, shape_, input.data(), input.size())),
      id_(MakeId()) {}

Tensor::Tensor(int64_t input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}),
      data_(MakeTensorData(data_type_, {}, input)),
      id_(MakeId()) {}

Tensor::Tensor(double input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {}),
      data_(MakeTensorData(data_type_, {}, input)),
      id_(MakeId()) {}

Tensor::Tensor(uint64_t input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeUInt64), {}),
      data_(MakeTensorData(data_type_, {}, input)),
      id_(MakeId()) {}

Tensor::Tensor(bool input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeBool), {}),
      data_(MakeTensorData(data_type_, {}, input)),
      id_(MakeId()) {}

bool Tensor::operator==(const Tensor &tensor) const {
  return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_));
}

bool Tensor::ValueEqual(const Tensor &tensor) const {
  return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
}

// Assign value to this tensor.
Tensor &Tensor::AssignValue(const Tensor &tensor) {
  if (this != &tensor) {
    MetaTensor::operator=(tensor);
    device_sync_ = tensor.device_sync_;
    need_release_device_mem_ = tensor.need_release_device_mem_;
    data_ = tensor.data_;
    id_ = tensor.id_;
    event_ = tensor.event_;
    need_wait_ = tensor.need_wait_;
    sync_status_ = tensor.sync_status_;
    padding_type_ = tensor.padding_type_;
    device_event_ = tensor.device_event_;
  }
  return *this;
}

abstract::AbstractBasePtr Tensor::ToAbstract() {
  auto tens = shared_from_base<Tensor>();
  auto dtype = tens->Dtype();
  if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
    MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
  }
  auto tensor_shape = tens->shape();
  auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
  // A parameter tensor is abstracted by reference and never carries a value.
  if (is_parameter_) {
    auto param_name = param_info_->name();
    auto ref_key = std::make_shared<RefKey>(param_name);
    auto abs_ref_key = ref_key->ToAbstract();
    abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
  } else {
    abs_tensor->set_value(shared_from_base<Tensor>());
  }
  return abs_tensor;
}

std::string Tensor::GetShapeAndDataTypeInfo() const {
  std::ostringstream buf;
  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
  return buf.str();
}

std::string Tensor::ToStringInternal(int limit_size) const {
  std::ostringstream buf;
  auto dtype = Dtype();
  MS_EXCEPTION_IF_NULL(dtype);
  buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", value=";
  if (limit_size <= 0 || DataSize() < limit_size) {
    // Only print data for small tensors.
    buf << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_, false);
  } else {
    buf << " [...]";
  }
  if (is_parameter_) {
    buf << ", name=" << param_info_->name();
  }
  buf << ")";
  return buf.str();
}

std::string Tensor::ToString() const {
  constexpr int small_tensor_size = 30;
  return ToStringInternal(small_tensor_size);
}

std::string Tensor::ToStringNoLimit() const { return ToStringInternal(0); }

std::string Tensor::ToStringRepr() const {
  std::ostringstream buf;
  auto dtype = Dtype();
  MS_EXCEPTION_IF_NULL(dtype);
  buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString()
      << ", value=" << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_, true) << ')';
  return buf.str();
}

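// Copies data from the bound device address back into host memory, optionally
// waiting for outstanding events first; afterwards the tensor is marked as
// needing a host-to-device sync.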
void Tensor::data_sync(bool need_wait) const {
  if (need_wait) {
    Wait();
  }
  if (device_sync_ == nullptr) {
    return;
  }
  std::vector<size_t> shape_tmp;
  (void)std::transform(shape().begin(), shape().end(), std::back_inserter(shape_tmp), IntToSize);
  auto size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(data_type());
  auto address = device_sync_;
  if (size != 0 && !address->SyncDeviceToHost(shape(), size, data_type(), data_c())) {
    MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
  }
  sync_status_ = kNeedSyncHostToDevice;
}

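// Changes the element type of the tensor, converting the existing data in
// place; a no-op when the new type equals the current one.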
TypeId Tensor::set_data_type(const TypeId data_type) {
  if (data_type != data_type_) {
    data_ = MakeTensorData(data_type, shape_, data_->data(), data_type_);
    return MetaTensor::set_data_type(data_type);
  }
  return data_type;
}
}  // namespace tensor
}  // namespace mindspore
