1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ir/tensor.h"
18 
19 #include <cstdint>
20 #include <exception>
21 #include <iomanip>
22 #include <functional>
23 #include <memory>
24 #include <utility>
25 #include <algorithm>
26 #include <map>
27 #include <vector>
28 #include "mindapi/base/type_id.h"
29 #include "abstract/utils.h"
30 #include "abstract/abstract_value.h"
31 #include "base/complex_storage.h"
32 #include "utils/log_adapter.h"
33 #include "mindspore/ccsrc/include/common/utils/convert_utils.h"
34 #include "utils/shape_utils.h"
35 #include "utils/ordered_set.h"
36 #include "utils/system/env.h"
37 #include "utils/temp_file_manager.h"
38 
39 namespace mindspore {
40 namespace tensor {
41 // TensorSubData is the base class that provides tensor data as a segment of an owner tensor's data.
42 class TensorSubData : public TensorData {
43  public:
44   TensorSubData(const TensorPtr &data_owner, size_t offset, size_t data_size, size_t ndim)
45       : data_owner_(data_owner), data_offset_(offset), data_size_(data_size), ndim_(ndim) {}
46   TensorSubData(const BaseTensorPtr &data_owner, size_t offset, size_t data_size, size_t ndim)
47       : data_owner_(data_owner), data_offset_(offset), data_size_(data_size), ndim_(ndim) {}
48 
49   ~TensorSubData() override = default;
50 
51   ssize_t size() const override { return static_cast<ssize_t>(data_size_); }
52 
53   ssize_t nbytes() const override { return size() * itemsize(); }
54 
55   ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }
56 
57   bool is_sub_data() const override { return true; }
58 
59   bool has_sub_data() const override { return false; }
60 
61   void *data() override {
62     // Mark data as initialized once data() has been called.
63     data_initialized_ = true;
64     auto start = static_cast<uint8_t *>(data_owner_->data().data());
65     return static_cast<void *>(start + data_offset_);
66   }
67 
68   const void *const_data() const override {
69     if (!data_initialized_) {
70       // Return nullptr if data not initialized.
71       return nullptr;
72     }
73     auto start = static_cast<uint8_t *>(data_owner_->data().data());
74     return static_cast<void *>(start + data_offset_);
75   }
76 
77   // Get the owner Tensor.
78   const BaseTensorPtr &GetOwner() const { return data_owner_; }
79 
80   // Data offset in bytes.
81   size_t data_offset() const { return data_offset_; }
82 
83  protected:
84   const BaseTensorPtr data_owner_;
85   size_t data_offset_{0};
86   size_t data_size_{0};
87   size_t ndim_{0};
88   bool data_initialized_{false};
89 };
90 
91 // TensorSubDataImpl implements methods that rely on T.
92 template <typename T>
93 class TensorSubDataImpl : public TensorSubData {
94  public:
95   TensorSubDataImpl(const TensorPtr &data_owner, size_t offset, size_t data_size, size_t ndim)
96       : TensorSubData(data_owner, offset, data_size, ndim) {}
97   TensorSubDataImpl(const BaseTensorPtr &data_owner, size_t offset, size_t data_size, size_t ndim)
98       : TensorSubData(data_owner, offset, data_size, ndim) {}
99 
100   ~TensorSubDataImpl() override = default;
101 
102   ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); }
103 
104   std::string ToString(TypeId type, const ShapeVector &shape, bool use_comma) const override {
105     TensorStringifier<T> stringifier{static_cast<const T *>(const_data()), data_size_, ndim_};
106     return stringifier.ToString(type, shape, use_comma);
107   }
108 };
109 
110 TensorDataPtr MakeTensorSubData(const BaseTensorPtr &owner, size_t offset, const TensorDataPtr &data) {
111   if (data->nbytes() == 0) {
112     MS_LOG(INTERNAL_EXCEPTION) << "Tensor data size is 0.";
113   }
114   auto sub_data =
115     tensor::MakeTensorData<TensorSubDataImpl>(owner->data_type(), owner, offset, data->size(), data->ndim());
116   // If tensor data is initialized, copy it.
117   if (data->const_data() != nullptr) {
118     CopyTensorData(sub_data, data);
119   }
120   return sub_data;
121 }
122 
123 // TensorChunk holds info for a chunk.
124 struct TensorChunk {
125   size_t size{0};                      // chunk size in the number of elements.
126   size_t bytes{0};                     // chunk size in bytes.
127   std::vector<BaseTensorPtr> tensors;  // tensors belonging to this chunk.
128 };
129 
130 static TypeId normalize_type(TypeId type_id) {
131   if (type_id == kNumberTypeFloat) {
132     // kNumberTypeFloat is an alias of kNumberTypeFloat32.
133     return kNumberTypeFloat32;
134   }
135   return type_id;
136 }
137 
138 Tensor::Tensor(const Tensor &tensor)
139     : BaseTensor(tensor),
140       init_flag_(tensor.init_flag_),
141       need_release_device_mem_(tensor.need_release_device_mem_),
142       cache_enable_(tensor.cache_enable_),
143       cache_tensor_ptr_(tensor.cache_tensor_ptr_),
144       hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
145       pin_mem_register_(tensor.pin_mem_register_),
146       compression_type_(tensor.compression_type_),
147       tensor_name_(tensor.tensor_name_),
148       device_info_(tensor.device_info_) {}
149 
150 Tensor::Tensor(const Tensor &tensor, TypeId data_type)
151     : BaseTensor(tensor, data_type),
152       init_flag_(tensor.init_flag_),
153       need_release_device_mem_(tensor.need_release_device_mem_),
154       cache_enable_(tensor.cache_enable_),
155       cache_tensor_ptr_(tensor.cache_tensor_ptr_),
156       hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
157       pin_mem_register_(tensor.pin_mem_register_),
158       compression_type_(tensor.compression_type_),
159       tensor_name_(tensor.tensor_name_),
160       device_info_(tensor.device_info_) {}
161 
162 Tensor::Tensor(const BaseTensor &tensor, TypeId data_type) : BaseTensor(tensor, data_type) {}
163 
164 Tensor::Tensor(const BaseTensor &base_tensor) : BaseTensor(base_tensor) {}
165 
166 Tensor &Tensor::operator=(const Tensor &tensor) {
167   if (this == &tensor) {
168     return *this;
169   }
170   BaseTensor::operator=(tensor);
171   init_flag_ = tensor.init_flag_;
172   need_release_device_mem_ = tensor.need_release_device_mem_;
173   cache_enable_ = tensor.cache_enable_;
174   cache_tensor_ptr_ = tensor.cache_tensor_ptr_;
175   hashmap_tensor_ptr_ = tensor.hashmap_tensor_ptr_;
176   pin_mem_register_ = tensor.pin_mem_register_;
177   compression_type_ = tensor.compression_type_;
178   tensor_name_ = tensor.tensor_name_;
179   adapter_flag_ = tensor.adapter_flag_;
180   cast_dtype_ = tensor.cast_dtype_;
181   graph_output_ = tensor.graph_output_;
182   quant_params_ = tensor.quant_params_;
183   updated_by_device_ = tensor.updated_by_device_;
184   device_info_ = tensor.device_info_;
185   return *this;
186 }
187 
188 Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data) : BaseTensor(data_type, shape, data) {}
189 
190 Tensor::Tensor(TypeId data_type, const ShapeVector &shape) : BaseTensor(data_type, shape) {}
191 
192 Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len)
193     : BaseTensor(data_type, shape, data, data_len) {}
194 
195 Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type)
196     : BaseTensor(data_type, shape, data, src_data_type) {}
197 
198 Tensor::Tensor(const std::vector<int64_t> &input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
199 
200 Tensor::Tensor(const std::vector<int32_t> &input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
201 
202 Tensor::Tensor(const std::vector<double> &input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
203 
204 Tensor::Tensor(const std::vector<float> &input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
205 
206 Tensor::Tensor(int64_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
207 
208 Tensor::Tensor(int32_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
209 
210 Tensor::Tensor(int16_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
211 
212 Tensor::Tensor(int8_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
213 
214 Tensor::Tensor(double input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
215 
216 Tensor::Tensor(float input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
217 
218 Tensor::Tensor(float16 input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
219 #ifndef KERNEL_EXECUTOR_ANDROID
220 Tensor::Tensor(bfloat16 input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
221 #endif
222 Tensor::Tensor(uint64_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
223 
224 Tensor::Tensor(uint32_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
225 
226 Tensor::Tensor(uint16_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
227 
228 Tensor::Tensor(uint8_t input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
229 
230 Tensor::Tensor(bool input, const TypePtr &data_type) : BaseTensor(input, data_type) {}
231 
232 Tensor::Tensor(TypeId data_type, size_t data_size) : BaseTensor(data_type, data_size) {}
233 
234 Tensor::Tensor(TypeId origin_data_type, const ShapeVector &shape, size_t compression_data_size,
235                TensorCompressionType compression_type)
236     : BaseTensor(origin_data_type, shape, compression_data_size, compression_type) {
237   compression_type_ = compression_type;
238 }
239 
240 Tensor::~Tensor() {
241   try {
242     UnPinMemory();
243     pin_mem_register_ = nullptr;
244   } catch (const std::exception &e) {
245     MS_LOG(ERROR) << "Exception when destruct tensor. Error info " << e.what();
246   }
247 }
248 
249 bool Tensor::operator==(const Tensor &tensor) const {
250   return (&tensor == this || (BaseTensor::operator==(tensor) && data_ == tensor.data_));
251 }
252 
253 // Assign value to this tensor.
254 Tensor &Tensor::AssignValue(const Tensor &tensor) {
255   if (this != &tensor) {
256     BaseTensor::AssignValue(tensor);
257     device_info_ = tensor.device_info_;
258     need_release_device_mem_ = tensor.need_release_device_mem_;
259 
260     // Need to execute the callback when the host value of the Tensor is updated.
261     ExecuteUpdateValueCallback();
262   }
263   return *this;
264 }
265 
266 abstract::AbstractBasePtr Tensor::ToAbstract() {
267   auto abs_tensor = BaseTensor::ToAbstract()->cast<abstract::AbstractTensorPtr>();
268   if (is_adapter()) {
269     abs_tensor->set_is_adapter(true);
270   }
271   return abs_tensor;
272 }
273 
274 void Tensor::data_sync(bool need_wait) const { BaseTensor::data_sync(need_wait); }
275 
276 void Tensor::ExecuteUpdateValueCallback() const {
277   if (update_value_callback_ != nullptr) {
278     update_value_callback_(this);
279   }
280 }
281 
282 void Tensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type, const std::string &host_format) {
283   DeviceInfo info(format, data_type, host_format);
284   set_device_info(info);
285 }
286 
287 void Tensor::data_sync_directly(const DeviceSync *const device_sync, bool need_wait) const {
288   if (need_wait) {
289     ExecuteLazyTask();
290   }
291   if (device_sync == nullptr) {
292     return;
293   }
294   MS_EXCEPTION_IF_NULL(data_);
295   if (data_->is_sub_data()) {
296     return;
297   }
298 
299   std::vector<size_t> shape_tmp;
300   (void)std::transform(shape().begin(), shape().end(), std::back_inserter(shape_tmp), IntToSize);
301   auto size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(data_type());
302   if (size != 0 && !device_sync->SyncDeviceToHost(shape(), size, data_type(), data_c())) {
303     MS_LOG(INTERNAL_EXCEPTION) << "SyncDeviceToHost failed.";
304   }
305   sync_status_ = kNeedSyncHostToDevice;
306 }
307 
308 bool Tensor::Offload(const std::string &file_path) {
309   if (file_path.empty()) {
310     return false;
311   }
312 
313   auto fs = mindspore::system::Env::GetFileSystem();
314   MS_EXCEPTION_IF_NULL(fs);
315   MS_EXCEPTION_IF_NULL(data_);
316   auto data_ptr = data_->data();
317   auto file = fs->CreateWriteFile(file_path);
318   MS_EXCEPTION_IF_NULL(file);
319   TempFileManager::GetInstance().Register(file_path);
320   bool success = file->PWrite(data_ptr, LongToSize(data_->nbytes()), 0);
321   if (!file->Close()) {
322     MS_LOG(WARNING) << "Close tensor file: " << file_path << " failed!";
323   }
324   if (!success) {
325     MS_LOG(WARNING) << "Tensor write data to file: " << file_path << " failed!";
326     return false;
327   }
328 
329   if (file_path == GetOffloadFilePath()) {
330     data_->set_file_path("");
331   }
332 
333   data_ = tensor::MakeTensorData(data_type_, shape_);
334   MS_EXCEPTION_IF_NULL(data_);
335   data_->set_file_path(file_path);
336   return true;
337 }
338 
339 const std::string Tensor::GetOffloadFilePath() const {
340   if (data_ == nullptr) {
341     return "";
342   }
343   return data_->file_path();
344 }
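// Illustrative offload flow, added here as a hedged sketch for readers of this
// listing (the file path below is hypothetical; only Offload() and
// GetOffloadFilePath() above are the actual interfaces):
//   TensorPtr t = ...;                               // a tensor with host data
//   if (t->Offload("/tmp/offload_0.bin")) {
//     // Host data was written to the file; data_ was replaced with a lazily
//     // allocated TensorData tagged with the file path.
//     std::string path = t->GetOffloadFilePath();    // "/tmp/offload_0.bin"
//   }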
345 
346 std::pair<void *, size_t> Tensor::GetChunkOffset() const {
347   // Get sub-data.
348   auto sub_data = std::dynamic_pointer_cast<TensorSubData>(data_ptr());
349   if (sub_data == nullptr) {
350     return {nullptr, 0};
351   }
352   // Get owner tensor from sub-data.
353   auto owner_tensor = sub_data->GetOwner();
354   MS_EXCEPTION_IF_NULL(owner_tensor);
355   return {owner_tensor->data_c(), sub_data->data_offset()};
356 }
357 
358 static std::map<TypeId, std::vector<TensorChunk>> GroupingTensors(const TensorPtrList &tensors, size_t fusion_size) {
359   // Use std::map to keep order by type id.
360   std::map<TypeId, std::vector<TensorChunk>> group_info;
361   for (auto &tensor : tensors) {
362     MS_EXCEPTION_IF_NULL(tensor);
363     auto tensor_bytes = static_cast<size_t>(tensor->data().nbytes());
364     if ((fusion_size != 0) && (tensor_bytes > fusion_size)) {
365       MS_LOG(EXCEPTION) << "Fusion size " << fusion_size << " is too small for a tensor size " << tensor_bytes << ".";
366     }
367     auto &chunks = group_info[normalize_type(tensor->data_type())];
368     if (chunks.empty()) {
369       (void)chunks.emplace_back();
370     }
371     if ((fusion_size != 0) && (chunks.back().bytes + tensor_bytes > fusion_size)) {
372       (void)chunks.emplace_back();
373     }
374     auto &chunk = chunks.back();
375     chunk.size += tensor->DataSize();
376     chunk.bytes += tensor_bytes;
377     (void)chunk.tensors.emplace_back(tensor);
378   }
379   return group_info;
380 }
381 
382 TensorPtrList Tensor::FlattenTensors(const TensorPtrList &tensors, size_t fusion_size) {
383   // Result tensor list.
384   TensorPtrList result_list;
385   // Grouping tensors by data type and fusion size.
386   auto group_info = GroupingTensors(tensors, fusion_size);
387   // Create chunk tensors and copy data to them.
388   for (auto &type_group : group_info) {
389     auto chunk_dtype = normalize_type(type_group.first);
390     for (auto &chunk : type_group.second) {
391       // Create the chunk tensor as a lazily initialized tensor; its data
392       // will be allocated when we begin to copy the small tensors' data into it.
393       auto chunk_tensor = std::make_shared<Tensor>(chunk_dtype, chunk.size);
394       // Reset and copy tensors data.
395       size_t offset = 0;
396       for (auto &tensor : chunk.tensors) {
397         auto sub_data = MakeTensorSubData(chunk_tensor, offset, tensor->data_ptr());
398         offset += static_cast<size_t>(sub_data->nbytes());
399         tensor->set_data(sub_data);
400       }
401       // Save chunk tensor to result list.
402       (void)result_list.emplace_back(std::move(chunk_tensor));
403     }
404   }
405   return result_list;
406 }
407 
408 bool Tensor::IsFlattened(const TensorPtrList &tensors) {
409   // Tensor data is flattened if every tensor's data is a TensorSubData.
410   return std::all_of(tensors.begin(), tensors.end(), [](const TensorPtr &tensor) {
411     MS_EXCEPTION_IF_NULL(tensor);
412     auto data_ptr = tensor->data_ptr().get();
413     return dynamic_cast<TensorSubData *>(data_ptr) != nullptr;
414   });
415 }
416 
417 TensorPtrList Tensor::GetFlattenedTensors(const TensorPtrList &tensors) {
418   // Use std::map to keep order by type id.
419   std::map<TypeId, OrderedSet<TensorPtr>> chunk_map;
420   for (auto &tensor : tensors) {
421     // Get sub-data.
422     auto sub_data = std::dynamic_pointer_cast<TensorSubData>(tensor->data_ptr());
423     if (sub_data == nullptr) {
424       MS_LOG(WARNING) << "Tensors are not flattened.";
425       return {};
426     }
427     // Get owner tensor from sub-data.
428     auto owner_tensor = std::dynamic_pointer_cast<Tensor>(sub_data->GetOwner());
429     MS_EXCEPTION_IF_NULL(owner_tensor);
430     // Add as chunk tensor by its data type.
431     auto chunk_dtype = normalize_type(tensor->data_type());
432     chunk_map[chunk_dtype].add(owner_tensor);
433   }
434   // Generate result tensor list.
435   TensorPtrList result_tensors;
436   for (auto &entry : chunk_map) {
437     auto &chunk_tensors = entry.second;
438     (void)result_tensors.insert(result_tensors.end(), chunk_tensors.begin(), chunk_tensors.end());
439   }
440   return result_tensors;
441 }
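// Illustrative flattening flow, added here as a hedged sketch (the tensors and
// values below are hypothetical; only the functions defined above are the
// actual interfaces):
//   TensorPtrList params = {t1, t2, t3};             // e.g. all kNumberTypeFloat32
//   TensorPtrList chunks = Tensor::FlattenTensors(params, /*fusion_size=*/0);
//   // Each input now holds a TensorSubData view into a chunk tensor:
//   assert(Tensor::IsFlattened(params));
//   auto [chunk_data, byte_offset] = t1->GetChunkOffset();
//   // The owner chunk tensors can be recovered, ordered by type id:
//   TensorPtrList owners = Tensor::GetFlattenedTensors(params);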
442 
443 bool Tensor::CheckStub() {
444 #if defined(WITH_BACKEND)
445   return false;
446 #else
447   auto context_ptr = MsContext::GetInstance();
448   MS_EXCEPTION_IF_NULL(context_ptr);
449   std::string backend_name = context_ptr->backend_policy();
450   if (backend_name == "vm") {
451     return false;
452   }
453   return true;
454 #endif
455 }
456 
457 size_t Tensor::GetFusionSize(const TensorPtrList &flat_tensors) {
458   size_t fusion_size = 0;
459   std::map<TypeId, size_t> type_groups;
460   for (auto &tensor : flat_tensors) {
461     MS_EXCEPTION_IF_NULL(tensor);
462     auto tensor_bytes = static_cast<size_t>(tensor->data().nbytes());
463     if (tensor_bytes > fusion_size) {
464       fusion_size = tensor_bytes;
465     }
466     ++type_groups[tensor->data_type()];
467   }
468   const bool only_one_chunk_for_each_type =
469     std::all_of(type_groups.begin(), type_groups.end(), [](auto const &e) { return e.second == 1; });
470   if (only_one_chunk_for_each_type) {
471     return 0;
472   }
473   return fusion_size;
474 }
475 
476 bool Tensor::is_persistent_data() const { return this->data().is_persistent_data(); }
477 
478 void Tensor::PinMemory(PinnedMemRegister *pin_mem_register) {
479   if (pin_mem_register == nullptr) {
480     return;
481   }
482   pin_mem_register_ = pin_mem_register;
483   pin_mem_register_->RegisterPinnedMem(data_c(), Size());
484 }
485 
486 void Tensor::UnPinMemory() {
487   if (pin_mem_register_ == nullptr) {
488     return;
489   }
490   pin_mem_register_->UnRegisterPinnedMem(data_c());
491 }
492 
493 CSRTensor::CSRTensor(const TensorPtr indptr, const TensorPtr indices, const TensorPtr values, const ShapeVector &shape)
494     : MetaSparseTensor(values->data_type(), shape), indptr_(indptr), indices_(indices), values_(values) {}
495 
496 std::string CSRTensor::ToString() const {
497   std::ostringstream buf;
498   MS_EXCEPTION_IF_NULL(values_);
499   MS_EXCEPTION_IF_NULL(indices_);
500   MS_EXCEPTION_IF_NULL(indptr_);
501   auto dtype = values_->Dtype();
502   values_->data_sync(true);
503   indices_->data_sync(true);
504   indptr_->data_sync(true);
505   buf << "CSRTensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ", indptr=";
506   buf << indptr_->ToString() << ", indices=" << indices_->ToString() << ", values=";
507   buf << values_->ToString() << ")";
508   return buf.str();
509 }
510 
511 abstract::AbstractBasePtr CSRTensor::ToAbstract() {
512   auto dtype = values_->Dtype();
513   if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
514     MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
515   }
516 
517   auto indptr = indptr_->ToAbstract()->cast<abstract::AbstractTensorPtr>();
518   auto indices = indices_->ToAbstract()->cast<abstract::AbstractTensorPtr>();
519   auto values = values_->ToAbstract()->cast<abstract::AbstractTensorPtr>();
520   std::vector<abstract::AbstractBasePtr> abstract_shape;
521   (void)std::transform(
522     shape_.begin(), shape_.end(), std::back_inserter(abstract_shape),
523     [](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared<abstract::AbstractScalar>(shp); });
524   auto shape = std::make_shared<abstract::AbstractTuple>(abstract_shape);
525   AbstractBasePtrList element_list{indptr, indices, values, shape};
526 
527   return std::make_shared<abstract::AbstractCSRTensor>(element_list);
528 }
529 
530 const size_t CSRTensor::GetSizeAt(size_t index) const {
531   if (index == kIndptrIdx) {
532     MS_EXCEPTION_IF_NULL(indptr_);
533     return indptr_->data().nbytes();
534   } else if (index == kIndicesIdx) {
535     MS_EXCEPTION_IF_NULL(indices_);
536     return indices_->data().nbytes();
537   } else if (index == kValuesIdx) {
538     MS_EXCEPTION_IF_NULL(values_);
539     return values_->data().nbytes();
540   } else if (index >= kIndicesIdx && index < kShapeIdx + shape().size()) {
541     return sizeof(int64_t);
542   }
543   MS_LOG(EXCEPTION) << "Invalid index: " << index << " for CSRTensor: " << ToString();
544 }
545 
546 TensorPtr CSRTensor::GetTensorAt(size_t index) const {
547   if (index == kIndptrIdx) {
548     MS_EXCEPTION_IF_NULL(indptr_);
549     return indptr_;
550   } else if (index == kIndicesIdx) {
551     MS_EXCEPTION_IF_NULL(indices_);
552     return indices_;
553   } else if (index == kValuesIdx) {
554     MS_EXCEPTION_IF_NULL(values_);
555     return values_;
556   } else if (index >= kShapeIdx && index < kShapeIdx + shape().size()) {
557     return std::make_shared<tensor::Tensor>(shape_[index - kShapeIdx], TypeIdToType(kNumberTypeInt64));
558   }
559   MS_LOG(EXCEPTION) << "Invalid index: " << index << " for CSRTensor: " << ToString();
560 }
561 
562 TensorPtr COOTensor::GetTensorAt(size_t index) const {
563   if (index == kIndicesIdx) {
564     MS_EXCEPTION_IF_NULL(indices_);
565     return indices_;
566   } else if (index == kValuesIdx) {
567     MS_EXCEPTION_IF_NULL(values_);
568     return values_;
569   } else if (index >= kShapeIdx && index < kShapeIdx + shape().size()) {
570     return std::make_shared<tensor::Tensor>(shape_[index - kShapeIdx], TypeIdToType(kNumberTypeInt64));
571   }
572   MS_LOG(EXCEPTION) << "Invalid index: " << index << " for COOTensor: " << ToString();
573 }
574 
575 std::string COOTensor::ToString() const {
576   std::ostringstream buf;
577   MS_EXCEPTION_IF_NULL(indices_);
578   MS_EXCEPTION_IF_NULL(values_);
579   indices_->data_sync(true);
580   values_->data_sync(true);
581   auto dtype = values_->Dtype();
582   buf << "COOTensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString()
583       << ", indices=" << indices_->ToString() << ", values=" << values_->ToString() << ")";
584   return buf.str();
585 }
586 
587 abstract::AbstractBasePtr COOTensor::ToAbstract() {
588   MS_EXCEPTION_IF_NULL(values_);
589   auto dtype = values_->Dtype();
590   if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
591     MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
592   }
593   MS_EXCEPTION_IF_NULL(indices_);
594   MS_EXCEPTION_IF_NULL(indices_->ToAbstract());
595   MS_EXCEPTION_IF_NULL(values_->ToAbstract());
596   auto indices = indices_->ToAbstract()->cast<abstract::AbstractTensorPtr>();
597   auto values = values_->ToAbstract()->cast<abstract::AbstractTensorPtr>();
598   std::vector<abstract::AbstractBasePtr> abstract_shape;
599   (void)std::transform(
600     shape_.begin(), shape_.end(), std::back_inserter(abstract_shape),
601     [](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared<abstract::AbstractScalar>(shp); });
602   auto shape = std::make_shared<abstract::AbstractTuple>(abstract_shape);
603   AbstractBasePtrList element_list{indices, values, shape};
604 
605   return std::make_shared<abstract::AbstractCOOTensor>(element_list);
606 }
607 
608 std::string RowTensor::ToString() const {
609   std::ostringstream buf;
610   MS_EXCEPTION_IF_NULL(indices_);
611   MS_EXCEPTION_IF_NULL(values_);
612   auto dtype = values_->Dtype();
613   buf << "RowTensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString()
614       << ", indices=" << indices_->ToString() << ", values=" << values_->ToString() << ")";
615   return buf.str();
616 }
617 
618 abstract::AbstractBasePtr RowTensor::ToAbstract() {
619   auto dtype = values_->Dtype();
620   if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) {
621     MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << ".";
622   }
623   auto abs_sparse_tensor = std::make_shared<abstract::AbstractRowTensor>(dtype, shape_);
624   MS_EXCEPTION_IF_NULL(indices_);
625   MS_EXCEPTION_IF_NULL(indices_->ToAbstract());
626   MS_EXCEPTION_IF_NULL(values_->ToAbstract());
627   abs_sparse_tensor->set_indices(indices_->ToAbstract()->cast<abstract::AbstractTensorPtr>());
628   abs_sparse_tensor->set_values(values_->ToAbstract()->cast<abstract::AbstractTensorPtr>());
629 
630   std::vector<abstract::AbstractBasePtr> abstract_shape;
631   (void)std::transform(
632     shape_.begin(), shape_.end(), std::back_inserter(abstract_shape),
633     [](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared<abstract::AbstractScalar>(shp); });
634   abs_sparse_tensor->set_dense_shape(std::make_shared<abstract::AbstractTuple>(abstract_shape));
635 
636   return abs_sparse_tensor;
637 }
638 }  // namespace tensor
639 }  // namespace mindspore
640