/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_IR_TENSOR_H_
#define MINDSPORE_CORE_IR_TENSOR_H_

#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "base/bfloat16.h"
#include "base/float16.h"
#include "ir/base_tensor.h"
#include "ir/device_event.h"
#include "ir/device_sync.h"
#include "ir/meta_grad_data.h"
#include "ir/quantization_param.h"
#include "ir/tensor_data.h"
#include "utils/log_adapter.h"
#include "utils/ms_exception.h"
#include "utils/os.h"
#include "utils/shape_utils.h"

// brief mindspore namespace.
//
// mindspore namespace is the top-level namespace of the MindSpore project.
// Other namespaces should be sub-namespaces of the mindspore namespace in the ME project.
namespace mindspore {
46 // Pinned memory register interface.
47 class MS_CORE_API PinnedMemRegister {
48  public:
49   /// \brief Default constructor for register.
50   PinnedMemRegister() = default;
51 
52   /// \brief Virtual destructor for register.
53   virtual ~PinnedMemRegister() = default;
54 
55   /// \brief Register pinned memory.
56   ///
57   /// \param[in] addr The host address to pin.
58   /// \param[in] size The host data size.
59   /// \return Void.
60   virtual void RegisterPinnedMem(void *addr, size_t size) = 0;
61 
62   /// \brief UnRegister pinned memory.
63   ///
64   /// \param[in] addr The host address to unpin.
65   /// \return Void.
66   virtual void UnRegisterPinnedMem(void *addr) = 0;
67 };

// A sub-namespace in ME to support tensor related definitions.
namespace tensor {
class Tensor;
using TensorPtr = std::shared_ptr<Tensor>;
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;

/// \brief Holder pairing an asynchronously produced value with any exception
/// captured while producing it.
template <typename T>
class FutureData {
 public:
  /// \brief Construct from a produced value and an optional captured exception.
  ///
  /// \param[in] data The produced value; may be null when production failed.
  /// \param[in] e_ptr The exception captured during production, or nullptr on success.
  FutureData(std::shared_ptr<T> data, std::exception_ptr e_ptr) : data_(std::move(data)), e_ptr_(std::move(e_ptr)) {}

  /// \brief Virtual destructor for safe destruction through base pointers.
  // Idiom fix: `= default` instead of an empty body; behavior is identical.
  virtual ~FutureData() = default;

  /// \brief Get the produced value.
  ///
  /// \return The value, or nullptr when an exception occurred instead.
  virtual std::shared_ptr<T> GetData() const { return data_; }

  /// \brief Get the exception captured during production.
  ///
  /// \return The exception pointer; nullptr when no exception occurred.
  const std::exception_ptr &GetException() const { return e_ptr_; }

 private:
  std::shared_ptr<T> data_;
  std::exception_ptr e_ptr_;
};
88 
89 template <typename T>
90 class FutureBase {
91  public:
FutureBase(std::future<std::shared_ptr<tensor::FutureData<T>>> future)92   explicit FutureBase(std::future<std::shared_ptr<tensor::FutureData<T>>> future) : future_(std::move(future)) {}
~FutureBase()93   virtual ~FutureBase() {}
94   virtual std::shared_ptr<T> Get() = 0;
95 
96  protected:
97   std::future<std::shared_ptr<tensor::FutureData<T>>> future_;
98   std::shared_ptr<tensor::FutureData<T>> future_data_;
99 };
100 // brief Device info of Tensor
101 //
102 // Includes the format, data type and host format of a tensor.
103 struct DeviceInfo {
104   explicit DeviceInfo(std::string format = "DefaultFormat", TypePtr data_type = nullptr,
105                       std::string host_format = "DefaultFormat", int32_t device_id = 0)
format_DeviceInfo106       : format_(std::move(format)),
107         data_type_(std::move(data_type)),
108         host_format_(std::move(host_format)),
109         device_id_(device_id) {}
110   std::string format_ = "DefaultFormat";
111   TypePtr data_type_ = nullptr;
112   std::string host_format_ = "DefaultFormat";
113   int32_t device_id_ = 0;
114 };
115 
116 // Tensor entity class
117 class MS_CORE_API Tensor : public BaseTensor {
118  public:
119   Tensor() = default;
120 
121   /// \brief Create tensor from another tensor, data is shared.
122   ///
123   /// \param[in] tensor [Tensor] The input tensor.
124   explicit Tensor(const Tensor &tensor);
125   /// \brief Create tensor with given data type from another tensor.
126   ///
127   /// \param[in] tensor [Tensor] The input tensor.
128   /// \param[in] data_type [TypeId] The new tensor data type.
129   Tensor(const Tensor &tensor, TypeId data_type);
130 
131   /// \brief Create tensor with base tensor.
132   ///
133   /// \param[in] tensor [Tensor] The input base tensor.
134   explicit Tensor(const BaseTensor &tensor);
135 
136   /// \brief Create tensor with given data type from another tensor.
137   ///
138   /// \param[in] tensor [Tensor] The input tensor.
139   /// \param[in] data_type [TypeId] The new tensor data type.
140   Tensor(const BaseTensor &tensor, TypeId data_type);
141 
142   /// \brief Create tensor with the given shared tensor data.
143   ///
144   /// \param[in] data_type [TypeId] Data type of the tensor.
145   /// \param[in] shape The shape represented by ShapeVector of the tensor.
146   /// \param[in] data The shared tensor data.
147   Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data);
148 
149   /// \brief Create a lazy allocated tensor.
150   ///
151   /// \param[in] data_type [TypeId] Data type of the tensor.
152   /// \param[in] shape The shape represented by ShapeVector of the tensor.
153   Tensor(TypeId data_type, const ShapeVector &shape);
154 
155   /// \brief Create a tensor with input data buffer.
156   ///
157   /// \param[in] data_type [TypeId] Data type of the tensor.
158   /// \param[in] shape The shape represented by ShapeVector of the tensor.
159   /// \param[in] data The input data to be copied into tensor.
160   /// \param[in] data_len The length of data in bytes.
161   Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len);
162 
163   /// \brief Create a tensor with input data buffer and given source data type.
164   ///
165   /// \param[in] data_type [TypeId] Data type of the tensor.
166   /// \param[in] shape The shape represented by ShapeVector of the tensor.
167   /// \param[in] data The input data to be copied into tensor.
168   /// \param[in] src_data_type The source data type.
169   Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type);
170 
171   /// \brief Create 1 dimension tensor from an int vector.
172   ///
173   /// \param[in] input [std::vector<int64_t>] the data for tensor.
174   /// \param[in] data_type [TypeId] data type.
175   explicit Tensor(const std::vector<int64_t> &input, const TypePtr &data_type = nullptr);
176 
177   /// \brief Create 1 dimension tensor from an int vector.
178   ///
179   /// \param[in] input [std::vector<int32_t>] the data for tensor.
180   /// \param[in] data_type [TypeId] data type.
181   explicit Tensor(const std::vector<int32_t> &input, const TypePtr &data_type = nullptr);
182 
183   /// \brief Create 1 dimension tensor from a float vector.
184   ///
185   /// \param[in] input [std::vector<double>] the data for tensor.
186   /// \param[in] data_type [TypeId] data type.
187   explicit Tensor(const std::vector<double> &input, const TypePtr &data_type = nullptr);
188 
189   /// \brief Create 1 dimension tensor from a float vector.
190   ///
191   /// \param[in] input [std::vector<float>] the data for tensor.
192   /// \param[in] data_type [TypeId] data type.
193   explicit Tensor(const std::vector<float> &input, const TypePtr &data_type = nullptr);
194 
195   /// \brief Create 0 dimension tensor from an int64_t scalar.
196   ///
197   /// \param[in] input [int64] the data for tensor.
198   /// \param[in] data_type [TypeId] data type.
199   explicit Tensor(int64_t input, const TypePtr &data_type = nullptr);
200 
201   /// \brief Create 0 dimension tensor from an int32_t scalar.
202   ///
203   /// \param[in] input [int32] the data for tensor.
204   /// \param[in] data_type [TypeId] data type.
205   explicit Tensor(int32_t input, const TypePtr &data_type = nullptr);
206 
207   /// \brief Create 0 dimension tensor from an int16_t scalar.
208   ///
209   /// \param[in] input [int16] the data for tensor.
210   /// \param[in] data_type [TypeId] data type.
211   explicit Tensor(int16_t input, const TypePtr &data_type = nullptr);
212 
213   /// \brief Create 0 dimension tensor from an int8_t scalar.
214   ///
215   /// \param[in] input [int8] the data for tensor.
216   /// \param[in] data_type [TypeId] data type.
217   explicit Tensor(int8_t input, const TypePtr &data_type = nullptr);
218 
219   /// \brief Create 0 dimension tensor from a double scalar.
220   ///
221   /// \param[in] input [double] the data for tensor.
222   /// \param[in] data_type [TypeId] data type.
223   explicit Tensor(double input, const TypePtr &data_type = nullptr);
224 
225   /// \brief Create 0 dimension tensor from a float scalar.
226   ///
227   /// \param[in] input [float] the data for tensor.
228   /// \param[in] data_type [TypeId] data type.
229   explicit Tensor(float input, const TypePtr &data_type = nullptr);
230 
231   /// \brief Create 0 dimension tensor from a float16 scalar.
232   ///
233   /// \param[in] input [float16] the data for tensor.
234   /// \param[in] data_type [TypeId] data type.
235   explicit Tensor(float16 input, const TypePtr &data_type = nullptr);
236 
237   /// \brief Create 0 dimension tensor from a bfloat16 scalar.
238   ///
239   /// \param[in] input [bfloat16] the data for tensor.
240   /// \param[in] data_type [TypeId] data type.
241   explicit Tensor(bfloat16 input, const TypePtr &data_type = nullptr);
242 
243   /// \brief Create 0 dimension tensor from a uint64 scalar.
244   ///
245   /// \param[in] input [uint64] the data for tensor.
246   /// \param[in] data_type [TypeId] data type.
247   explicit Tensor(uint64_t input, const TypePtr &data_type = nullptr);
248 
249   /// \brief Create 0 dimension tensor from a uint32 scalar.
250   ///
251   /// \param[in] input [uint32] the data for tensor.
252   /// \param[in] data_type [TypeId] data type.
253   explicit Tensor(uint32_t input, const TypePtr &data_type = nullptr);
254 
255   /// \brief Create 0 dimension tensor from a uint16 scalar.
256   ///
257   /// \param[in] input [uint16] the data for tensor.
258   /// \param[in] data_type [TypeId] data type.
259   explicit Tensor(uint16_t input, const TypePtr &data_type = nullptr);
260 
261   /// \brief Create 0 dimension tensor from a uint8 scalar.
262   ///
263   /// \param[in] input [uint8] the data for tensor.
264   /// \param[in] data_type [TypeId] data type.
265   explicit Tensor(uint8_t input, const TypePtr &data_type = nullptr);
266 
267   /// \brief Create 0 dimension tensor from a bool scalar.
268   ///
269   /// \param[in] input [bool] the data for tensor.
270   /// \param[in] data_type [TypeId] data type.
271   explicit Tensor(bool input, const TypePtr &data_type = nullptr);
272 
273   /// \brief Create a chunk tensor with the given data size.
274   ///
275   /// \param[in] data_type [TypeId] Data type of the tensor.
276   /// \param[in] data_size The tensor chunk data size in number of elements.
277   Tensor(TypeId data_type, size_t data_size);
278 
279   /// \brief Create a Tensor which shape and size may be inconsistent, such as Tensor with compression data.
280   ///
281   /// \param[in] origin_data_type [TypeId] Data type of the origin tensor.
282   /// \param[in] shape The shape represented by ShapeVector of the tensor.
283   /// \param[in] compression_data_size The compression data buffer size.
284   /// \param[in] TensorCompressionType The tensor compression type.
285   Tensor(TypeId origin_data_type, const ShapeVector &shape, size_t compression_data_size,
286          TensorCompressionType compression_type);
287 
288   Tensor &operator=(const Tensor &tensor);
289 
290   /// Destructor of Tensor.
291   ~Tensor() override;
292 
293   MS_DECLARE_PARENT(Tensor, BaseTensor);
294 
295   /// \brief Compare two tensor objects to see if they have same data type, shape and data address.
296   ///
297   /// \param[in] tensor The Tensor object to be compared.
298   /// \return True if having same type, shape and data address, otherwise false.
299   bool operator==(const Tensor &tensor) const;
300 
301   /// \brief Create Abstract for Tensor.
302   ///
303   /// \return Abstract of Tensor.
304   abstract::AbstractBasePtr ToAbstract() override;
305 
306   /// \brief Assign value to this tensor.
307   ///
308   /// \param[in] tensor The input tensor.
309   /// \return Tensor with new value.
310   Tensor &AssignValue(const Tensor &tensor);
311 
312   bool operator==(const Value &other) const override {
313     if (other.isa<Tensor>()) {
314       auto &other_ = static_cast<const Tensor &>(other);
315       return *this == other_;
316     }
317     return false;
318   }
319 
320   /// \brief To synchronize data with the device, you need to wait for the data to be valid.
321   ///
322   void data_sync(bool need_wait = true) const;
323 
324   /// \brief To synchronize data with the device without keeping device address, you need to wait for the data to be
325   /// valid.
326   ///
327   void data_sync_directly(const DeviceSync *const device_sync, bool need_wait = true) const;
328 
329   /// \brief Check if this Tensor is initialized.
330   ///
331   /// \return Whether this Tensor is initialized.
is_init()332   bool is_init() const { return init_flag_; }
333 
334   /// \brief Set the initialization flag of this Tensor.
335   ///
336   /// \param[in] flag Whether this Tensor is initialized.
set_init_flag(bool flag)337   void set_init_flag(bool flag) { init_flag_ = flag; }
338 
339   /// \brief Check whether this Tensor needs to be converted.
340   ///
341   /// \return Whether this Tensor needs to be converted.
is_adapter()342   bool is_adapter() const { return adapter_flag_; }
343 
344   /// \brief Set the adapter flag of this Tensor.
345   ///
346   /// \param[in] flag Whether this Tensor needs to be converted.
set_adapter_flag(bool flag)347   void set_adapter_flag(bool flag) { adapter_flag_ = flag; }
348 
349   /// \brief Check whether to release device memory.
350   ///
351   /// \return Ture if need to release device memory, otherwise false.
need_release_device_mem()352   bool need_release_device_mem() const { return need_release_device_mem_; }
353 
354   /// \brief Set the flag to determine whether the device memory needs to be released.
355   ///
356   /// \param[in] release_device_mem If release_device_mem is ture, the device memory will to be released.
set_need_release_device_mem(bool release_device_mem)357   void set_need_release_device_mem(bool release_device_mem) { need_release_device_mem_ = release_device_mem; }
358 
359   /// \brief Get the cast dtype of this Tensor.
360   ///
361   /// \return The cast dtype of this Tensor.
cast_dtype()362   TypePtr cast_dtype() { return cast_dtype_; }
363 
364   /// \brief Set the cast dtype of this Tensor.
365   ///
366   /// \param[in] dtype The input cast dtype.
367   void set_cast_dtype(const TypePtr &dtype = nullptr) { cast_dtype_ = dtype; }
368 
369   /// \brief Used cache_enable to update the tensor from the cache to the host.
370   ///
371   /// \return True if caching is enabled, otherwise false.
cache_enable()372   bool cache_enable() const { return cache_enable_; }
373 
374   /// \brief Set cache_enable.
375   ///
376   /// \param[in] cache_enable Whether to enable caching.
377   void set_cache_enable(bool cache_enable = true) { cache_enable_ = cache_enable; }
378 
379   /// \brief Get the pointer of hashmap tensor.
380   ///
381   /// \return The pointer of hashmap tensor.
hashmap_tensor_ptr()382   std::shared_ptr<Tensor> hashmap_tensor_ptr() const { return hashmap_tensor_ptr_; }
383 
384   /// \brief Set the pointer of hashmap tensor.
385   ///
386   /// \param[in] hashmap_tensor_ptr The input pointer of hashmap tensor.
387   void set_hashmap_tensor_ptr(const std::shared_ptr<Tensor> &hashmap_tensor_ptr = nullptr) {
388     hashmap_tensor_ptr_ = hashmap_tensor_ptr;
389   }
390 
391   /// \brief Get the pointer of cache tensor.
392   ///
393   /// \return The pointer of cache tensor.
cache_tensor_ptr()394   std::shared_ptr<Tensor> cache_tensor_ptr() const { return cache_tensor_ptr_; }
395 
396   /// \brief Set the pointer of cache tensor.
397   ///
398   /// \param[in] cache_tensor_ptr The input pointer of cache tensor.
399   void set_cache_tensor_ptr(const std::shared_ptr<Tensor> &cache_tensor_ptr = nullptr) {
400     cache_tensor_ptr_ = cache_tensor_ptr;
401   }
402 
403   /// \brief Check if this Tensor is the output of graph.
404   ///
405   /// \return Whether this Tensor is the output of graph
IsGraphOutput()406   bool IsGraphOutput() const { return graph_output_; }
407 
408   /// \brief Set whether this Tensor is the output of graph.
SetIsGraphOutput()409   void SetIsGraphOutput() { graph_output_ = true; }
410 
411   /// \brief Get whether this Tensor is updated by the device.
412   ///
413   /// \return Whether this Tensor is updated by the device.
IsUpdatedByDevice()414   bool IsUpdatedByDevice() const { return updated_by_device_; }
415 
416   /// \brief Set whether this Tensor is updated by the device.
SetIsUpdateByDevice()417   void SetIsUpdateByDevice() { updated_by_device_ = true; }
418 
419   /// \brief Get callback need to execute when value is updated of Tensor.
420   ///
421   /// \return The callback need to execute when value is updated of Tensor.
update_value_callback()422   const std::function<void(const Tensor *)> &update_value_callback() const { return update_value_callback_; }
423 
424   /// \brief Set callback need to execute when value is updated of Tensor.
425   ///
426   /// \param[in] update_value_callback The callback need to execute when value is updated of Tensor.
set_update_value_callback(const std::function<void (const Tensor *)> & update_value_callback)427   void set_update_value_callback(const std::function<void(const Tensor *)> &update_value_callback) {
428     update_value_callback_ = update_value_callback;
429   }
430 
431   /// \brief Get the memory chunk pointer and offset if memory chunk for this tensor exists.
432   ///
433   /// \return The memory chunk pointer and offset, nullptr and 0 if no memory chunk exists.
434   std::pair<void *, size_t> GetChunkOffset() const;
435 
436   /// \brief Reset tensors data so that they are using contiguous memory chunks grouped by data type.
437   ///
438   /// \param[in] tensors The tensors to be processed.
439   /// \param[in] fusion_size Maximum memory chunk size in bytes, 0 for unlimited.
440   ///
441   /// \return Tensors that data are pointed to each contiguous memory chunks.
442   static TensorPtrList FlattenTensors(const TensorPtrList &tensors, size_t fusion_size = 0);
443 
444   /// \brief Check if FlattenTensors called for the input tensors.
445   ///
446   /// \param[in] tensors The tensors to be checked.
447   ///
448   /// \return True if FlattenTensors called for input tensors, false otherwise.
449   static bool IsFlattened(const TensorPtrList &tensors);
450 
451   /// \brief Get tensors for each contiguous memory chunks used by the input tensors.
452   ///
453   /// \param[in] tensors The input tensors.
454   ///
455   /// \return Tensors that data are pointed to each contiguous memory chunks, empty if failed.
456   static TensorPtrList GetFlattenedTensors(const TensorPtrList &tensors);
457 
458   /// \brief Get tensors stub flag.
459   ///
460   /// \param[in] none.
461   ///
462   /// \return If compile with backend, return false, else return true.
463   static bool CheckStub();
464 
465   /// \brief Get the fusion size for the given flat tensors.
466   ///
467   /// \param[in] flat_tensors The input flat tensors.
468   ///
469   /// \return fusion size for the given flat tensors.
470   static size_t GetFusionSize(const TensorPtrList &flat_tensors);
471 
472   /// \brief Get the tensor compression type.
473   ///
474   /// \return tensor compression type.
compression_type()475   TensorCompressionType compression_type() const { return compression_type_; }
476 
477   /// \brief If tensor use persistent tensor data.
478   ///
479   /// \return if use persistent tenor data.
480   bool is_persistent_data() const;
481 
482   /// \brief Set tensor name.
483   ///
484   /// \param[in] tensor_name The tensor name.
set_name(const std::string & tensor_name)485   void set_name(const std::string &tensor_name) { tensor_name_ = tensor_name; }
486 
487   /// \brief Get the tensor name.
488   ///
489   /// \return tensor name.
name()490   const std::string &name() const { return tensor_name_; }
491 
492   /// \brief Set tensor quant param.
493   ///
494   /// \param[in] quant_param The tensor quant param.
set_quant_param(const std::vector<std::shared_ptr<QuantizationParam>> & quant_params)495   void set_quant_param(const std::vector<std::shared_ptr<QuantizationParam>> &quant_params) {
496     quant_params_.assign(quant_params.begin(), quant_params.end());
497   }
498 
499   /// \brief Get the tensor quant param.
500   ///
501   /// \return tensor quant param.
quant_params()502   const std::vector<std::shared_ptr<QuantizationParam>> &quant_params() const { return quant_params_; }
503 
504   /// \brief Offload tensor to file.
505   ///
506   /// \return offload tensor success.
507   bool Offload(const std::string &file_path);
508 
509   /// \brief Get tensor offload file path.
510   ///
511   /// \return offload file path, or empty string if tensor has not offload.
512   const std::string GetOffloadFilePath() const;
513 
514   /// \brief pin tensor memory.
515   ///
516   /// \param[in] register to pin tensor data.
517   void PinMemory(PinnedMemRegister *pin_mem_register);
518 
519   /// \brief unpin tensor memory.
520   void UnPinMemory();
521 
522   /// \brief Get tensor's device info.
523   ///
524   /// \return The device info of this tensor.
device_info()525   DeviceInfo device_info() const { return device_info_; }
526 
527   /// \brief Set tensor's device info.
528   ///
529   /// \param[in] device_info The tensor's device info.
set_device_info(const DeviceInfo & device_info)530   void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; }
531 
532   /// \brief Set tensor's device info.
533   ///
534   /// \param[in] format The input format.
535   /// \param[in] data_type The input data type.
536   /// \param[in] host_format The input host format.
537   void SetDeviceInfo(const std::string &format, const TypePtr &data_type,
538                      const std::string &host_format = "DefaultFormat");
539 
set_copy_done_flag(bool flag)540   void set_copy_done_flag(bool flag) { copy_done_flag_ = flag; }
get_copy_done_flag()541   bool get_copy_done_flag() const { return copy_done_flag_; }
542   bool copy_done_flag_{false};
543 
544  private:
545   // Really execute callback function when host value is updated of Tensor.
546   void ExecuteUpdateValueCallback() const;
547 
548   bool init_flag_{false};
549   bool adapter_flag_{false};
550   bool graph_output_{false};
551   bool updated_by_device_{false};
552   // Release device address of graph output tensor or not.
553   bool need_release_device_mem_{false};
554   bool cache_enable_{false};
555   std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr};
556   std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr};
557   TypePtr cast_dtype_{nullptr};
558   std::function<void(const Tensor *)> update_value_callback_{nullptr};
559   PinnedMemRegister *pin_mem_register_{nullptr};
560   TensorCompressionType compression_type_{kNoCompression};
561   std::vector<std::shared_ptr<QuantizationParam>> quant_params_;
562   std::string tensor_name_;
563   // brief Device info of Tensor
564   //
565   // Includes the format and data type of a tensor on device.
566   DeviceInfo device_info_;
567 };
568 
569 // CSRTensor entity class
570 class MS_CORE_API CSRTensor : public MetaSparseTensor {
571  public:
572   abstract::AbstractBasePtr ToAbstract() override;
573 
574   /// \brief Create CSRTensor with given data type from another tensor.
575   ///
576   /// \param[in] indptr [Tensor] The indices pointer.
577   /// \param[in] indices [Tensor] The indices.
578   /// \param[in] values [Tensor] The values.
579   /// \param[in] shape The shape represented by ShapeVector of the CSRensor.
580   CSRTensor(const TensorPtr indptr, const TensorPtr indices, const TensorPtr values, const ShapeVector &shape);
581 
582   /// Destructor of CSRTensor.
583   ~CSRTensor() override = default;
584 
MS_DECLARE_PARENT(CSRTensor,MetaSparseTensor)585   MS_DECLARE_PARENT(CSRTensor, MetaSparseTensor)
586 
587   /// \brief Gets CSRTensor's indptr.
588   ///
589   /// \return [TensorPtr] The indices pointer.
590   TensorPtr GetIndptr() { return indptr_; }
591 
592   /// \brief Gets CSRTensor's indices.
593   ///
594   /// \return [TensorPtr] The indices.
GetIndices()595   TensorPtr GetIndices() { return indices_; }
596 
597   /// \brief Gets CSRTensor's values.
598   ///
599   /// \return [TensorPtr] The values.
GetValues()600   TensorPtr GetValues() { return values_; }
601 
602   /// \brief Compare two csrtensor objects to see if they have same data address.
603   ///
604   /// \param[in] csr_tensor The csrtensor object to be compared.
605   /// \return True if having same data address, otherwise false.
606   bool operator==(const CSRTensor &csr_tensor) const { return &csr_tensor == this; }
607 
608   bool operator==(const Value &other) const override {
609     if (other.isa<CSRTensor>()) {
610       auto &other_ = static_cast<const CSRTensor &>(other);
611       return *this == other_;
612     }
613     return false;
614   }
615 
616   const size_t GetSizeAt(size_t index) const;
617 
618   TensorPtr GetTensorAt(size_t index) const;
619 
GetTensorLength()620   const size_t GetTensorLength() const { return kShapeIdx + shape().size(); }
621 
622   /// \brief Get display information of this Tensor.
623   ///
624   /// \return The display information of this Tensor.
625   std::string ToString() const override;
626 
627   static constexpr size_t kIndptrIdx = 0;
628   static constexpr size_t kIndicesIdx = 1;
629   static constexpr size_t kValuesIdx = 2;
630   static constexpr size_t kShapeIdx = 3;
631 
632  private:
633   TensorPtr indptr_;
634   TensorPtr indices_;
635   TensorPtr values_;
636 };
637 using CSRTensorPtr = std::shared_ptr<CSRTensor>;
638 
639 // COOTensor entity class
640 class MS_CORE_API COOTensor : public MetaSparseTensor {
641  public:
642   abstract::AbstractBasePtr ToAbstract() override;
643 
644   /// \brief Create COOTensor with given data type from another tensor.
645   ///
646   /// \param[in] indices [Tensor] The indices.
647   /// \param[in] values [Tensor] The values.
648   /// \param[in] shape The shape represented by ShapeVector of the COOTensor.
COOTensor(const TensorPtr indices,const TensorPtr values,const ShapeVector & shape)649   COOTensor(const TensorPtr indices, const TensorPtr values, const ShapeVector &shape)
650       : MetaSparseTensor(values->data_type(), shape), indices_(indices), values_(values) {}
651 
652   /// Destructor of COOTensor.
653   ~COOTensor() override = default;
654 
MS_DECLARE_PARENT(COOTensor,MetaSparseTensor)655   MS_DECLARE_PARENT(COOTensor, MetaSparseTensor)
656 
657   /// \brief Gets COOTensor's indices.
658   ///
659   /// \return [TensorPtr] The indices.
660   TensorPtr GetIndices() { return indices_; }
661 
662   /// \brief Gets COOTensor's values.
663   ///
664   /// \return [TensorPtr] The values.
GetValues()665   TensorPtr GetValues() { return values_; }
666 
667   TensorPtr GetTensorAt(size_t index) const;
668 
GetTensorLength()669   const size_t GetTensorLength() const { return kShapeIdx + shape().size(); }
670 
671   /// \brief Compare two cootensor objects to see if they have same address.
672   ///
673   /// \param[in] coo_tensor The cootensor object to be compared.
674   /// \return True if having same data address, otherwise false.
675   bool operator==(const COOTensor &coo_tensor) const { return &coo_tensor == this; }
676 
677   bool operator==(const Value &other) const override {
678     if (other.isa<COOTensor>()) {
679       auto &other_ = static_cast<const COOTensor &>(other);
680       return *this == other_;
681     }
682     return false;
683   }
684 
685   /// \brief Get display information of this Tensor.
686   ///
687   /// \return The display information of this Tensor.
688   std::string ToString() const override;
689 
690   static constexpr size_t kIndicesIdx = 0;
691   static constexpr size_t kValuesIdx = 1;
692   static constexpr size_t kShapeIdx = 2;
693 
694  private:
695   TensorPtr indices_;
696   TensorPtr values_;
697 };
698 using COOTensorPtr = std::shared_ptr<COOTensor>;
699 
700 // RowTensor entity class
701 class MS_CORE_API RowTensor : public MetaSparseTensor {
702  public:
703   abstract::AbstractBasePtr ToAbstract() override;
704 
705   /// \brief Create RowTensor with given data type from another tensor.
706   ///
707   /// \param[in] indices [Tensor] The indices.
708   /// \param[in] values [Tensor] The values.
709   /// \param[in] shape The shape represented by ShapeVector of the RowTensor.
RowTensor(const TensorPtr indices,const TensorPtr values,const ShapeVector & shape)710   RowTensor(const TensorPtr indices, const TensorPtr values, const ShapeVector &shape)
711       : MetaSparseTensor(values->data_type(), shape), indices_(indices), values_(values) {}
712 
713   /// Destructor of RowTensor.
714   ~RowTensor() override = default;
715 
716   /// \brief Gets RowTensor's indices.
717   ///
718   /// \return [TensorPtr] The indices.
GetIndices()719   TensorPtr GetIndices() { return indices_; }
720 
721   /// \brief Gets RowTensor's values.
722   ///
723   /// \return [TensorPtr] The values.
GetValues()724   TensorPtr GetValues() { return values_; }
725 
726   /// \brief Compare two rowtensor objects to see if they have same address.
727   ///
728   /// \param[in] coo_tensor The rowtensor object to be compared.
729   /// \return True if having same data address, otherwise false.
730   bool operator==(const RowTensor &row_tensor) const { return &row_tensor == this; }
731 
732   bool operator==(const Value &other) const override {
733     if (other.isa<RowTensor>()) {
734       auto &other_ = static_cast<const RowTensor &>(other);
735       return *this == other_;
736     }
737     return false;
738   }
739 
740   /// \brief Get display information of this Tensor.
741   ///
742   /// \return The display information of this Tensor.
743   std::string ToString() const override;
744 
745  private:
746   TensorPtr indices_;
747   TensorPtr values_;
748 };
749 using RowTensorPtr = std::shared_ptr<RowTensor>;
}  // namespace tensor
}  // namespace mindspore

#endif  // MINDSPORE_CORE_IR_TENSOR_H_
754