1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ir/tensor.h"
18
19 #include <cstdint>
20 #include <exception>
21 #include <iomanip>
22 #include <functional>
23 #include <memory>
24 #include <utility>
25 #include <algorithm>
26 #include <map>
27 #include <vector>
28 #include "mindapi/base/type_id.h"
29 #include "abstract/utils.h"
30 #include "abstract/abstract_value.h"
31 #include "base/complex_storage.h"
32 #include "utils/log_adapter.h"
33 #include "mindspore/ccsrc/include/common/utils/convert_utils.h"
34 #include "utils/shape_utils.h"
35 #include "utils/ordered_set.h"
36 #include "utils/system/env.h"
37 #include "utils/temp_file_manager.h"
38
39 namespace mindspore {
40 namespace tensor {
41 // TensorSubData is the base class to provide tensor data as a segment from an owner tensor data.
42 class TensorSubData : public TensorData {
43 public:
TensorSubData(const TensorPtr & data_owner,size_t offset,size_t data_size,size_t ndim)44 TensorSubData(const TensorPtr &data_owner, size_t offset, size_t data_size, size_t ndim)
45 : data_owner_(data_owner), data_offset_(offset), data_size_(data_size), ndim_(ndim) {}
TensorSubData(const BaseTensorPtr & data_owner,size_t offset,size_t data_size,size_t ndim)46 TensorSubData(const BaseTensorPtr &data_owner, size_t offset, size_t data_size, size_t ndim)
47 : data_owner_(data_owner), data_offset_(offset), data_size_(data_size), ndim_(ndim) {}
48
49 ~TensorSubData() override = default;
50
size() const51 ssize_t size() const override { return static_cast<ssize_t>(data_size_); }
52
nbytes() const53 ssize_t nbytes() const override { return size() * itemsize(); }
54
ndim() const55 ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }
56
is_sub_data() const57 bool is_sub_data() const override { return true; }
58
has_sub_data() const59 bool has_sub_data() const override { return false; }
60
data()61 void *data() override {
62 // Set data initialized if data() is called.
63 data_initialized_ = true;
64 auto start = static_cast<uint8_t *>(data_owner_->data().data());
65 return static_cast<void *>(start + data_offset_);
66 }
67
const_data() const68 const void *const_data() const override {
69 if (!data_initialized_) {
70 // Return nullptr if data not initialized.
71 return nullptr;
72 }
73 auto start = static_cast<uint8_t *>(data_owner_->data().data());
74 return static_cast<void *>(start + data_offset_);
75 }
76
77 // Get the owner Tensor.
GetOwner() const78 const BaseTensorPtr &GetOwner() const { return data_owner_; }
79
80 // Data offset in bytes.
data_offset() const81 size_t data_offset() const { return data_offset_; }
82
83 protected:
84 const BaseTensorPtr data_owner_;
85 size_t data_offset_{0};
86 size_t data_size_{0};
87 size_t ndim_{0};
88 bool data_initialized_{false};
89 };
90
91 // TensorSubDataImpl implements methods that rely on T.
92 template <typename T>
93 class TensorSubDataImpl : public TensorSubData {
94 public:
TensorSubDataImpl(const TensorPtr & data_owner,size_t offset,size_t data_size,size_t ndim)95 TensorSubDataImpl(const TensorPtr &data_owner, size_t offset, size_t data_size, size_t ndim)
96 : TensorSubData(data_owner, offset, data_size, ndim) {}
TensorSubDataImpl(const BaseTensorPtr & data_owner,size_t offset,size_t data_size,size_t ndim)97 TensorSubDataImpl(const BaseTensorPtr &data_owner, size_t offset, size_t data_size, size_t ndim)
98 : TensorSubData(data_owner, offset, data_size, ndim) {}
99
100 ~TensorSubDataImpl() override = default;
101
itemsize() const102 ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); }
103
ToString(TypeId type,const ShapeVector & shape,bool use_comma) const104 std::string ToString(TypeId type, const ShapeVector &shape, bool use_comma) const override {
105 TensorStringifier<T> stringifier{static_cast<const T *>(const_data()), data_size_, ndim_};
106 return stringifier.ToString(type, shape, use_comma);
107 }
108 };
109
// Create a TensorSubData that views a segment of `owner`'s buffer starting at
// byte `offset`. The segment has the element count and ndim of `data` and the
// element type of `owner`. If `data` exposes initialized contents
// (const_data() != nullptr), they are copied into the new segment.
// Raises an internal exception if `owner` or `data` is null, or if `data` is empty.
TensorDataPtr MakeTensorSubData(const BaseTensorPtr &owner, size_t offset, const TensorDataPtr &data) {
  MS_EXCEPTION_IF_NULL(owner);
  MS_EXCEPTION_IF_NULL(data);
  if (data->nbytes() == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "Tensor data size is 0.";
  }
  auto sub_data =
    tensor::MakeTensorData<TensorSubDataImpl>(owner->data_type(), owner, offset, data->size(), data->ndim());
  // A null const_data() means there is nothing initialized to copy yet.
  if (data->const_data() != nullptr) {
    CopyTensorData(sub_data, data);
  }
  return sub_data;
}
122
// TensorChunk describes one chunk built by GroupingTensors: the tensors
// assigned to it and their accumulated size.
struct TensorChunk {
  size_t size{0};  // Chunk size as a number of elements (sum of DataSize()).
  size_t bytes{0};  // Chunk size in bytes (sum of data().nbytes()).
  std::vector<BaseTensorPtr> tensors;  // Tensors that belong to this chunk, in insertion order.
};
129
normalize_type(TypeId type_id)130 static TypeId normalize_type(TypeId type_id) {
131 if (type_id == kNumberTypeFloat) {
132 // kNumberTypeFloat is an alias of kNumberTypeFloat32.
133 return kNumberTypeFloat32;
134 }
135 return type_id;
136 }
137
// Copy constructor: copies the BaseTensor part plus the Tensor-specific
// fields listed in the initializer list.
// NOTE(review): operator= in this file additionally copies adapter_flag_,
// cast_dtype_, graph_output_, quant_params_ and updated_by_device_, which are
// not copied here; confirm whether that difference is intentional.
Tensor::Tensor(const Tensor &tensor)
    : BaseTensor(tensor),
      init_flag_(tensor.init_flag_),
      need_release_device_mem_(tensor.need_release_device_mem_),
      cache_enable_(tensor.cache_enable_),
      cache_tensor_ptr_(tensor.cache_tensor_ptr_),
      hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
      pin_mem_register_(tensor.pin_mem_register_),
      compression_type_(tensor.compression_type_),
      tensor_name_(tensor.tensor_name_),
      device_info_(tensor.device_info_) {}
149
// Copy `tensor` while passing `data_type` to the BaseTensor constructor. Any
// conversion of the data is handled there. The same Tensor-specific fields as
// the plain copy constructor are copied.
// NOTE(review): as with the copy constructor, adapter_flag_, cast_dtype_,
// graph_output_, quant_params_ and updated_by_device_ are not copied here.
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
    : BaseTensor(tensor, data_type),
      init_flag_(tensor.init_flag_),
      need_release_device_mem_(tensor.need_release_device_mem_),
      cache_enable_(tensor.cache_enable_),
      cache_tensor_ptr_(tensor.cache_tensor_ptr_),
      hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
      pin_mem_register_(tensor.pin_mem_register_),
      compression_type_(tensor.compression_type_),
      tensor_name_(tensor.tensor_name_),
      device_info_(tensor.device_info_) {}
161
// Build a Tensor from a BaseTensor, passing `data_type` through to the
// BaseTensor constructor. No Tensor-specific member is set here.
Tensor::Tensor(const BaseTensor &tensor, TypeId data_type) : BaseTensor(tensor, data_type) {}

// Build a Tensor from a BaseTensor. Only the BaseTensor part is copied.
Tensor::Tensor(const BaseTensor &base_tensor) : BaseTensor(base_tensor) {}
165
// Copy assignment: copies the BaseTensor part, then each Tensor-specific
// field below in the same order. Assigning a tensor to itself does nothing.
// NOTE(review): this copies more fields than the copy constructors do
// (adapter_flag_, cast_dtype_, graph_output_, quant_params_, updated_by_device_).
Tensor &Tensor::operator=(const Tensor &tensor) {
  if (this != &tensor) {
    BaseTensor::operator=(tensor);
    init_flag_ = tensor.init_flag_;
    need_release_device_mem_ = tensor.need_release_device_mem_;
    cache_enable_ = tensor.cache_enable_;
    cache_tensor_ptr_ = tensor.cache_tensor_ptr_;
    hashmap_tensor_ptr_ = tensor.hashmap_tensor_ptr_;
    pin_mem_register_ = tensor.pin_mem_register_;
    compression_type_ = tensor.compression_type_;
    tensor_name_ = tensor.tensor_name_;
    adapter_flag_ = tensor.adapter_flag_;
    cast_dtype_ = tensor.cast_dtype_;
    graph_output_ = tensor.graph_output_;
    quant_params_ = tensor.quant_params_;
    updated_by_device_ = tensor.updated_by_device_;
    device_info_ = tensor.device_info_;
  }
  return *this;
}
187
// Each constructor below only forwards its arguments to the BaseTensor
// constructor with the same parameter list. No Tensor-specific member is set
// in them, and interpretation of the arguments is up to BaseTensor.

// Data type, shape and an existing TensorData.
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data) : BaseTensor(data_type, shape, data) {}

// Data type and shape only.
Tensor::Tensor(TypeId data_type, const ShapeVector &shape) : BaseTensor(data_type, shape) {}

// Data type, shape and a raw buffer with its length.
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len)
    : BaseTensor(data_type, shape, data, data_len) {}

// Data type, shape and a raw buffer whose elements are of src_data_type.
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type)
    : BaseTensor(data_type, shape, data, src_data_type) {}

// std::vector inputs with an explicit element type.
Tensor::Tensor(const std::vector<int64_t> &input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(const std::vector<int32_t> &input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(const std::vector<double> &input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(const std::vector<float> &input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

// Single-value inputs with an explicit element type.
Tensor::Tensor(int64_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(int32_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(int16_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(int8_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(double input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(float input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(float16 input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
#ifndef KERNEL_EXECUTOR_ANDROID
// bfloat16 is not available in KERNEL_EXECUTOR_ANDROID builds.
Tensor::Tensor(bfloat16 input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
#endif
Tensor::Tensor(uint64_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(uint32_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(uint16_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(uint8_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

Tensor::Tensor(bool input, const TypePtr &data_type) : BaseTensor(input, data_type) {}

// Data type and data size (FlattenTensors passes an element count here).
Tensor::Tensor(TypeId data_type, size_t data_size) : BaseTensor(data_type, data_size) {}
233
Tensor(TypeId origin_data_type,const ShapeVector & shape,size_t compression_data_size,TensorCompressionType compression_type)234 Tensor::Tensor(TypeId origin_data_type, const ShapeVector &shape, size_t compression_data_size,
235 TensorCompressionType compression_type)
236 : BaseTensor(origin_data_type, shape, compression_data_size, compression_type) {
237 compression_type_ = compression_type;
238 }
239
~Tensor()240 Tensor::~Tensor() {
241 try {
242 UnPinMemory();
243 pin_mem_register_ = nullptr;
244 } catch (const std::exception &e) {
245 MS_LOG(ERROR) << "Exception when destruct tensor. Error info " << e.what();
246 }
247 }
248
operator ==(const Tensor & tensor) const249 bool Tensor::operator==(const Tensor &tensor) const {
250 return (&tensor == this || (BaseTensor::operator==(tensor) && data_ == tensor.data_));
251 }
252
253 // Assign value to this tensor.
AssignValue(const Tensor & tensor)254 Tensor &Tensor::AssignValue(const Tensor &tensor) {
255 if (this != &tensor) {
256 BaseTensor::AssignValue(tensor);
257 device_info_ = tensor.device_info_;
258 need_release_device_mem_ = tensor.need_release_device_mem_;
259
260 // Need execute callback when update host value of Tensor.
261 ExecuteUpdateValueCallback();
262 }
263 return *this;
264 }
265
ToAbstract()266 abstract::AbstractBasePtr Tensor::ToAbstract() {
267 auto abs_tensor = BaseTensor::ToAbstract()->cast<abstract::AbstractTensorPtr>();
268 if (is_adapter()) {
269 abs_tensor->set_is_adapter(true);
270 }
271 return abs_tensor;
272 }
273
// Forwards directly to BaseTensor::data_sync with the same need_wait flag.
void Tensor::data_sync(bool need_wait) const { BaseTensor::data_sync(need_wait); }
275
ExecuteUpdateValueCallback() const276 void Tensor::ExecuteUpdateValueCallback() const {
277 if (update_value_callback_ != nullptr) {
278 update_value_callback_(this);
279 }
280 }
281
SetDeviceInfo(const std::string & format,const TypePtr & data_type,const std::string & host_format)282 void Tensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type, const std::string &host_format) {
283 DeviceInfo info(format, data_type, host_format);
284 set_device_info(info);
285 }
286
data_sync_directly(const DeviceSync * const device_sync,bool need_wait) const287 void Tensor::data_sync_directly(const DeviceSync *const device_sync, bool need_wait) const {
288 if (need_wait) {
289 ExecuteLazyTask();
290 }
291 if (device_sync == nullptr) {
292 return;
293 }
294 MS_EXCEPTION_IF_NULL(data_);
295 if (data_->is_sub_data()) {
296 return;
297 }
298
299 std::vector<size_t> shape_tmp;
300 (void)std::transform(shape().begin(), shape().end(), std::back_inserter(shape_tmp), IntToSize);
301 auto size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(data_type());
302 if (size != 0 && !device_sync->SyncDeviceToHost(shape(), size, data_type(), data_c())) {
303 MS_LOG(INTERNAL_EXCEPTION) << "SyncDeviceToHost failed.";
304 }
305 sync_status_ = kNeedSyncHostToDevice;
306 }
307
// Write this tensor's raw bytes (data_->nbytes() bytes from data_->data()) to
// `file_path`. On success, replace data_ with a new TensorData from
// MakeTensorData(data_type_, shape_) that records file_path, and return true.
// Returns false if file_path is empty or the write fails. In that case data_
// is unchanged, but the path has already been registered with TempFileManager.
bool Tensor::Offload(const std::string &file_path) {
  if (file_path.empty()) {
    return false;
  }

  auto fs = mindspore::system::Env::GetFileSystem();
  MS_EXCEPTION_IF_NULL(fs);
  MS_EXCEPTION_IF_NULL(data_);
  auto data_ptr = data_->data();
  auto file = fs->CreateWriteFile(file_path);
  MS_EXCEPTION_IF_NULL(file);
  TempFileManager::GetInstance().Register(file_path);
  bool success = file->PWrite(data_ptr, LongToSize(data_->nbytes()), 0);
  // The file is closed whether or not the write succeeded; a close failure is only logged.
  if (!file->Close()) {
    MS_LOG(WARNING) << "Close tensor file: " << file_path << " failed!";
  }
  if (!success) {
    MS_LOG(WARNING) << "Tensor write data to file: " << file_path << " failed!";
    return false;
  }

  // If the current data already records this same path, clear it before data_ is replaced.
  // NOTE(review): presumably this stops the old TensorData from treating the
  // freshly written file as its own; confirm against TensorData.
  if (file_path == GetOffloadFilePath()) {
    data_->set_file_path("");
  }

  // Swap in a new TensorData of the same type and shape that remembers where the bytes were written.
  data_ = tensor::MakeTensorData(data_type_, shape_);
  MS_EXCEPTION_IF_NULL(data_);
  data_->set_file_path(file_path);
  return true;
}
338
GetOffloadFilePath() const339 const std::string Tensor::GetOffloadFilePath() const {
340 if (data_ == nullptr) {
341 return "";
342 }
343 return data_->file_path();
344 }
345
GetChunkOffset() const346 std::pair<void *, size_t> Tensor::GetChunkOffset() const {
347 // Get sub-data.
348 auto sub_data = std::dynamic_pointer_cast<TensorSubData>(data_ptr());
349 if (sub_data == nullptr) {
350 return {nullptr, 0};
351 }
352 // Get owner tensor from sub-data.
353 auto owner_tensor = sub_data->GetOwner();
354 MS_EXCEPTION_IF_NULL(owner_tensor);
355 return {owner_tensor->data_c(), sub_data->data_offset()};
356 }
357
GroupingTensors(const TensorPtrList & tensors,size_t fusion_size)358 static std::map<TypeId, std::vector<TensorChunk>> GroupingTensors(const TensorPtrList &tensors, size_t fusion_size) {
359 // Use std::map to keep order by type id.
360 std::map<TypeId, std::vector<TensorChunk>> group_info;
361 for (auto &tensor : tensors) {
362 MS_EXCEPTION_IF_NULL(tensor);
363 auto tensor_bytes = static_cast<size_t>(tensor->data().nbytes());
364 if ((fusion_size != 0) && (tensor_bytes > fusion_size)) {
365 MS_LOG(EXCEPTION) << "Fusion size " << fusion_size << " is too small for a tensor size " << tensor_bytes << ".";
366 }
367 auto &chunks = group_info[normalize_type(tensor->data_type())];
368 if (chunks.empty()) {
369 (void)chunks.emplace_back();
370 }
371 if ((fusion_size != 0) && (chunks.back().bytes + tensor_bytes > fusion_size)) {
372 (void)chunks.emplace_back();
373 }
374 auto &chunk = chunks.back();
375 chunk.size += tensor->DataSize();
376 chunk.bytes += tensor_bytes;
377 (void)chunk.tensors.emplace_back(tensor);
378 }
379 return group_info;
380 }
381
// Flatten `tensors` into chunk tensors. Tensors are grouped by normalized data
// type and fusion_size (see GroupingTensors), and one chunk Tensor is created
// per group. Each input tensor's data is replaced by a TensorSubData view into
// its chunk at a running byte offset, with existing values copied in.
// Returns the chunk tensors, ordered by type id.
TensorPtrList Tensor::FlattenTensors(const TensorPtrList &tensors, size_t fusion_size) {
  TensorPtrList chunk_tensors;
  const auto groups = GroupingTensors(tensors, fusion_size);
  for (const auto &group : groups) {
    const TypeId dtype = normalize_type(group.first);
    for (const auto &chunk : group.second) {
      // Lazily initialized: storage is allocated when the members' data is
      // first copied into it.
      auto chunk_tensor = std::make_shared<Tensor>(dtype, chunk.size);
      size_t byte_offset = 0;
      for (const auto &member : chunk.tensors) {
        auto view = MakeTensorSubData(chunk_tensor, byte_offset, member->data_ptr());
        byte_offset += static_cast<size_t>(view->nbytes());
        member->set_data(view);
      }
      (void)chunk_tensors.emplace_back(std::move(chunk_tensor));
    }
  }
  return chunk_tensors;
}
407
IsFlattened(const TensorPtrList & tensors)408 bool Tensor::IsFlattened(const TensorPtrList &tensors) {
409 // Tensor data is flattened if all tensors data are TensorSubData.
410 return std::all_of(tensors.begin(), tensors.end(), [](const TensorPtr &tensor) {
411 MS_EXCEPTION_IF_NULL(tensor);
412 auto data_ptr = tensor->data_ptr().get();
413 return dynamic_cast<TensorSubData *>(data_ptr) != nullptr;
414 });
415 }
416
// Collect the chunk (owner) tensors behind a flattened tensor list, grouped by
// normalized data type in type-id order. Returns an empty list, with a warning,
// if any tensor's data is not a TensorSubData view. Raises if an input tensor
// is null or an owner is not a Tensor.
TensorPtrList Tensor::GetFlattenedTensors(const TensorPtrList &tensors) {
  // std::map keeps the result ordered by type id. OrderedSet presumably keeps
  // each owner once, in insertion order.
  std::map<TypeId, OrderedSet<TensorPtr>> chunk_map;
  for (auto &tensor : tensors) {
    MS_EXCEPTION_IF_NULL(tensor);
    auto sub_data = std::dynamic_pointer_cast<TensorSubData>(tensor->data_ptr());
    if (sub_data == nullptr) {
      MS_LOG(WARNING) << "Tensors are not flattened.";
      return {};
    }
    auto owner_tensor = std::dynamic_pointer_cast<Tensor>(sub_data->GetOwner());
    MS_EXCEPTION_IF_NULL(owner_tensor);
    auto chunk_dtype = normalize_type(tensor->data_type());
    chunk_map[chunk_dtype].add(owner_tensor);
  }
  TensorPtrList result_tensors;
  for (auto &entry : chunk_map) {
    auto &chunk_tensors = entry.second;
    (void)result_tensors.insert(result_tensors.end(), chunk_tensors.begin(), chunk_tensors.end());
  }
  return result_tensors;
}
442
CheckStub()443 bool Tensor::CheckStub() {
444 #if defined(WITH_BACKEND)
445 return false;
446 #else
447 auto context_ptr = MsContext::GetInstance();
448 MS_EXCEPTION_IF_NULL(context_ptr);
449 std::string backend_name = context_ptr->backend_policy();
450 if (backend_name == "vm") {
451 return false;
452 }
453 return true;
454 #endif
455 }
456
GetFusionSize(const TensorPtrList & flat_tensors)457 size_t Tensor::GetFusionSize(const TensorPtrList &flat_tensors) {
458 size_t fusion_size = 0;
459 std::map<TypeId, size_t> type_groups;
460 for (auto &tensor : flat_tensors) {
461 MS_EXCEPTION_IF_NULL(tensor);
462 auto tensor_bytes = static_cast<size_t>(tensor->data().nbytes());
463 if (tensor_bytes > fusion_size) {
464 fusion_size = tensor_bytes;
465 }
466 ++type_groups[tensor->data_type()];
467 }
468 const bool only_one_chunk_for_each_type =
469 std::all_of(type_groups.begin(), type_groups.end(), [](auto const &e) { return e.second == 1; });
470 if (only_one_chunk_for_each_type) {
471 return 0;
472 }
473 return fusion_size;
474 }
475
// Forwards to is_persistent_data() of this tensor's TensorData.
bool Tensor::is_persistent_data() const { return this->data().is_persistent_data(); }
477
PinMemory(PinnedMemRegister * pin_mem_register)478 void Tensor::PinMemory(PinnedMemRegister *pin_mem_register) {
479 if (pin_mem_register == nullptr) {
480 return;
481 }
482 pin_mem_register_ = pin_mem_register;
483 pin_mem_register_->RegisterPinnedMem(data_c(), Size());
484 }
485
UnPinMemory()486 void Tensor::UnPinMemory() {
487 if (pin_mem_register_ == nullptr) {
488 return;
489 }
490 pin_mem_register_->UnRegisterPinnedMem(data_c());
491 }
492
// Build a CSR sparse tensor from its indptr, indices and values tensors and
// the dense shape. The element type is read from `values`, which is
// dereferenced here and must therefore be non-null.
CSRTensor::CSRTensor(const TensorPtr indptr, const TensorPtr indices, const TensorPtr values, const ShapeVector &shape)
    : MetaSparseTensor(values->data_type(), shape), indptr_(indptr), indices_(indices), values_(values) {}
495
ToString() const496 std::string CSRTensor::ToString() const {
497 std::ostringstream buf;
498 MS_EXCEPTION_IF_NULL(values_);
499 MS_EXCEPTION_IF_NULL(indices_);
500 MS_EXCEPTION_IF_NULL(indptr_);
501 auto dtype = values_->Dtype();
502 values_->data_sync(true);
503 indices_->data_sync(true);
504 indptr_->data_sync(true);
505 buf << "CSRTensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", indptr=";
506 buf << indptr_->ToString() << ", indices=" << indices_->ToString() << ", values=";
507 buf << values_->ToString() << ")";
508 return buf.str();
509 }
510
ToAbstract()511 abstract::AbstractBasePtr CSRTensor::ToAbstract() {
512 auto dtype = values_->Dtype();
513 if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
514 MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
515 }
516
517 auto indptr = indptr_->ToAbstract()->cast<abstract::AbstractTensorPtr>();
518 auto indices = indices_->ToAbstract()->cast<abstract::AbstractTensorPtr>();
519 auto values = values_->ToAbstract()->cast<abstract::AbstractTensorPtr>();
520 std::vector<abstract::AbstractBasePtr> abstract_shape;
521 (void)std::transform(
522 shape_.begin(), shape_.end(), std::back_inserter(abstract_shape),
523 [](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared<abstract::AbstractScalar>(shp); });
524 auto shape = std::make_shared<abstract::AbstractTuple>(abstract_shape);
525 AbstractBasePtrList element_list{indptr, indices, values, shape};
526
527 return std::make_shared<abstract::AbstractCSRTensor>(element_list);
528 }
529
GetSizeAt(size_t index) const530 const size_t CSRTensor::GetSizeAt(size_t index) const {
531 if (index == kIndptrIdx) {
532 MS_EXCEPTION_IF_NULL(indptr_);
533 return indptr_->data().nbytes();
534 } else if (index == kIndicesIdx) {
535 MS_EXCEPTION_IF_NULL(indices_);
536 return indices_->data().nbytes();
537 } else if (index == kValuesIdx) {
538 MS_EXCEPTION_IF_NULL(values_);
539 return values_->data().nbytes();
540 } else if (index >= kIndicesIdx && index < kShapeIdx + shape().size()) {
541 return sizeof(int64_t);
542 }
543 MS_LOG(EXCEPTION) << "Invalid index: " << index << " for CSRTensor: " << ToString();
544 }
545
GetTensorAt(size_t index) const546 TensorPtr CSRTensor::GetTensorAt(size_t index) const {
547 if (index == kIndptrIdx) {
548 MS_EXCEPTION_IF_NULL(indptr_);
549 return indptr_;
550 } else if (index == kIndicesIdx) {
551 MS_EXCEPTION_IF_NULL(indices_);
552 return indices_;
553 } else if (index == kValuesIdx) {
554 MS_EXCEPTION_IF_NULL(values_);
555 return values_;
556 } else if (index >= kShapeIdx && index < kShapeIdx + shape().size()) {
557 return std::make_shared<tensor::Tensor>(shape_[index - kShapeIdx], TypeIdToType(kNumberTypeInt64));
558 }
559 MS_LOG(EXCEPTION) << "Invalid index: " << index << " for CSRTensor: " << ToString();
560 }
561
GetTensorAt(size_t index) const562 TensorPtr COOTensor::GetTensorAt(size_t index) const {
563 if (index == kIndicesIdx) {
564 MS_EXCEPTION_IF_NULL(indices_);
565 return indices_;
566 } else if (index == kValuesIdx) {
567 MS_EXCEPTION_IF_NULL(values_);
568 return values_;
569 } else if (index >= kShapeIdx && index < kShapeIdx + shape().size()) {
570 return std::make_shared<tensor::Tensor>(shape_[index - kShapeIdx], TypeIdToType(kNumberTypeInt64));
571 }
572 MS_LOG(EXCEPTION) << "Invalid index: " << index << " for COOTensor: " << ToString();
573 }
574
ToString() const575 std::string COOTensor::ToString() const {
576 std::ostringstream buf;
577 MS_EXCEPTION_IF_NULL(indices_);
578 MS_EXCEPTION_IF_NULL(values_);
579 indices_->data_sync(true);
580 values_->data_sync(true);
581 auto dtype = values_->Dtype();
582 buf << "COOTensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString()
583 << ", indices=" << indices_->ToString() << ", values=" << values_->ToString() << ")";
584 return buf.str();
585 }
586
ToAbstract()587 abstract::AbstractBasePtr COOTensor::ToAbstract() {
588 MS_EXCEPTION_IF_NULL(values_);
589 auto dtype = values_->Dtype();
590 if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
591 MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
592 }
593 MS_EXCEPTION_IF_NULL(indices_);
594 MS_EXCEPTION_IF_NULL(indices_->ToAbstract());
595 MS_EXCEPTION_IF_NULL(values_->ToAbstract());
596 auto indices = indices_->ToAbstract()->cast<abstract::AbstractTensorPtr>();
597 auto values = values_->ToAbstract()->cast<abstract::AbstractTensorPtr>();
598 std::vector<abstract::AbstractBasePtr> abstract_shape;
599 (void)std::transform(
600 shape_.begin(), shape_.end(), std::back_inserter(abstract_shape),
601 [](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared<abstract::AbstractScalar>(shp); });
602 auto shape = std::make_shared<abstract::AbstractTuple>(abstract_shape);
603 AbstractBasePtrList element_list{indices, values, shape};
604
605 return std::make_shared<abstract::AbstractCOOTensor>(element_list);
606 }
607
ToString() const608 std::string RowTensor::ToString() const {
609 std::ostringstream buf;
610 MS_EXCEPTION_IF_NULL(indices_);
611 MS_EXCEPTION_IF_NULL(values_);
612 auto dtype = values_->Dtype();
613 buf << "RowTensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString()
614 << ", indices=" << indices_->ToString() << ", values=" << values_->ToString() << ")";
615 return buf.str();
616 }
617
ToAbstract()618 abstract::AbstractBasePtr RowTensor::ToAbstract() {
619 auto dtype = values_->Dtype();
620 if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
621 MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
622 }
623 auto abs_sparse_tensor = std::make_shared<abstract::AbstractRowTensor>(dtype, shape_);
624 MS_EXCEPTION_IF_NULL(indices_);
625 MS_EXCEPTION_IF_NULL(indices_->ToAbstract());
626 MS_EXCEPTION_IF_NULL(values_->ToAbstract());
627 abs_sparse_tensor->set_indices(indices_->ToAbstract()->cast<abstract::AbstractTensorPtr>());
628 abs_sparse_tensor->set_values(values_->ToAbstract()->cast<abstract::AbstractTensorPtr>());
629
630 std::vector<abstract::AbstractBasePtr> abstract_shape;
631 (void)std::transform(
632 shape_.begin(), shape_.end(), std::back_inserter(abstract_shape),
633 [](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared<abstract::AbstractScalar>(shp); });
634 abs_sparse_tensor->set_dense_shape(std::make_shared<abstract::AbstractTuple>(abstract_shape));
635
636 return abs_sparse_tensor;
637 }
638 } // namespace tensor
639 } // namespace mindspore
640