1 /** 2 * Copyright 2020-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_TENSOR_H_ 18 #define MINDSPORE_CORE_IR_TENSOR_H_ 19 20 #include <future> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 #include <numeric> 25 #include <mutex> 26 #include <condition_variable> 27 #include <utility> 28 #include "ir/device_sync.h" 29 #include "ir/base_tensor.h" 30 #include "utils/log_adapter.h" 31 #include "base/float16.h" 32 #include "base/bfloat16.h" 33 #include "utils/shape_utils.h" 34 #include "utils/ms_exception.h" 35 #include "ir/device_event.h" 36 #include "utils/os.h" 37 #include "ir/quantization_param.h" 38 #include "ir/meta_grad_data.h" 39 #include "ir/tensor_data.h" 40 41 // brief mindspore namespace. 42 // 43 // mindspore namespace is the top level namespace of MindSpore project. 44 // Other namespace should be a sub namespace of mindspore namespace in the ME project. 45 namespace mindspore { 46 // Pinned memory register interface. 47 class MS_CORE_API PinnedMemRegister { 48 public: 49 /// \brief Default constructor for register. 50 PinnedMemRegister() = default; 51 52 /// \brief Virtual destructor for register. 53 virtual ~PinnedMemRegister() = default; 54 55 /// \brief Register pinned memory. 56 /// 57 /// \param[in] addr The host address to pin. 58 /// \param[in] size The host data size. 59 /// \return Void. 60 virtual void RegisterPinnedMem(void *addr, size_t size) = 0; 61 62 /// \brief UnRegister pinned memory. 63 /// 64 /// \param[in] addr The host address to unpin. 65 /// \return Void. 66 virtual void UnRegisterPinnedMem(void *addr) = 0; 67 }; 68 69 // A sub namespace in ME to support tensor related definition. 70 namespace tensor { 71 class Tensor; 72 using TensorPtr = std::shared_ptr<Tensor>; 73 using TensorPtrList = std::vector<std::shared_ptr<Tensor>>; 74 75 template <typename T> 76 class FutureData { 77 public: FutureData(std::shared_ptr<T> data,std::exception_ptr e_ptr)78 FutureData(std::shared_ptr<T> data, std::exception_ptr e_ptr) : data_(std::move(data)), e_ptr_(std::move(e_ptr)) {} ~FutureData()79 virtual ~FutureData() {} 80 GetData()81 virtual std::shared_ptr<T> GetData() const { return data_; } GetException()82 const std::exception_ptr &GetException() const { return e_ptr_; } 83 84 private: 85 std::shared_ptr<T> data_; 86 std::exception_ptr e_ptr_; 87 }; 88 89 template <typename T> 90 class FutureBase { 91 public: FutureBase(std::future<std::shared_ptr<tensor::FutureData<T>>> future)92 explicit FutureBase(std::future<std::shared_ptr<tensor::FutureData<T>>> future) : future_(std::move(future)) {} ~FutureBase()93 virtual ~FutureBase() {} 94 virtual std::shared_ptr<T> Get() = 0; 95 96 protected: 97 std::future<std::shared_ptr<tensor::FutureData<T>>> future_; 98 std::shared_ptr<tensor::FutureData<T>> future_data_; 99 }; 100 // brief Device info of Tensor 101 // 102 // Includes the format, data type and host format of a tensor. 103 struct DeviceInfo { 104 explicit DeviceInfo(std::string format = "DefaultFormat", TypePtr data_type = nullptr, 105 std::string host_format = "DefaultFormat", int32_t device_id = 0) format_DeviceInfo106 : format_(std::move(format)), 107 data_type_(std::move(data_type)), 108 host_format_(std::move(host_format)), 109 device_id_(device_id) {} 110 std::string format_ = "DefaultFormat"; 111 TypePtr data_type_ = nullptr; 112 std::string host_format_ = "DefaultFormat"; 113 int32_t device_id_ = 0; 114 }; 115 116 // Tensor entity class 117 class MS_CORE_API Tensor : public BaseTensor { 118 public: 119 Tensor() = default; 120 121 /// \brief Create tensor from another tensor, data is shared. 122 /// 123 /// \param[in] tensor [Tensor] The input tensor. 124 explicit Tensor(const Tensor &tensor); 125 /// \brief Create tensor with given data type from another tensor. 126 /// 127 /// \param[in] tensor [Tensor] The input tensor. 128 /// \param[in] data_type [TypeId] The new tensor data type. 129 Tensor(const Tensor &tensor, TypeId data_type); 130 131 /// \brief Create tensor with base tensor. 132 /// 133 /// \param[in] tensor [Tensor] The input base tensor. 134 explicit Tensor(const BaseTensor &tensor); 135 136 /// \brief Create tensor with given data type from another tensor. 137 /// 138 /// \param[in] tensor [Tensor] The input tensor. 139 /// \param[in] data_type [TypeId] The new tensor data type. 140 Tensor(const BaseTensor &tensor, TypeId data_type); 141 142 /// \brief Create tensor with the given shared tensor data. 143 /// 144 /// \param[in] data_type [TypeId] Data type of the tensor. 145 /// \param[in] shape The shape represented by ShapeVector of the tensor. 146 /// \param[in] data The shared tensor data. 147 Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data); 148 149 /// \brief Create a lazy allocated tensor. 150 /// 151 /// \param[in] data_type [TypeId] Data type of the tensor. 152 /// \param[in] shape The shape represented by ShapeVector of the tensor. 153 Tensor(TypeId data_type, const ShapeVector &shape); 154 155 /// \brief Create a tensor with input data buffer. 156 /// 157 /// \param[in] data_type [TypeId] Data type of the tensor. 158 /// \param[in] shape The shape represented by ShapeVector of the tensor. 159 /// \param[in] data The input data to be copied into tensor. 160 /// \param[in] data_len The length of data in bytes. 161 Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len); 162 163 /// \brief Create a tensor with input data buffer and given source data type. 164 /// 165 /// \param[in] data_type [TypeId] Data type of the tensor. 166 /// \param[in] shape The shape represented by ShapeVector of the tensor. 167 /// \param[in] data The input data to be copied into tensor. 168 /// \param[in] src_data_type The source data type. 169 Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type); 170 171 /// \brief Create 1 dimension tensor from an int vector. 172 /// 173 /// \param[in] input [std::vector<int64_t>] the data for tensor. 174 /// \param[in] data_type [TypeId] data type. 175 explicit Tensor(const std::vector<int64_t> &input, const TypePtr &data_type = nullptr); 176 177 /// \brief Create 1 dimension tensor from an int vector. 178 /// 179 /// \param[in] input [std::vector<int32_t>] the data for tensor. 180 /// \param[in] data_type [TypeId] data type. 181 explicit Tensor(const std::vector<int32_t> &input, const TypePtr &data_type = nullptr); 182 183 /// \brief Create 1 dimension tensor from a float vector. 184 /// 185 /// \param[in] input [std::vector<double>] the data for tensor. 186 /// \param[in] data_type [TypeId] data type. 187 explicit Tensor(const std::vector<double> &input, const TypePtr &data_type = nullptr); 188 189 /// \brief Create 1 dimension tensor from a float vector. 190 /// 191 /// \param[in] input [std::vector<float>] the data for tensor. 192 /// \param[in] data_type [TypeId] data type. 193 explicit Tensor(const std::vector<float> &input, const TypePtr &data_type = nullptr); 194 195 /// \brief Create 0 dimension tensor from an int64_t scalar. 196 /// 197 /// \param[in] input [int64] the data for tensor. 198 /// \param[in] data_type [TypeId] data type. 199 explicit Tensor(int64_t input, const TypePtr &data_type = nullptr); 200 201 /// \brief Create 0 dimension tensor from an int32_t scalar. 202 /// 203 /// \param[in] input [int32] the data for tensor. 204 /// \param[in] data_type [TypeId] data type. 205 explicit Tensor(int32_t input, const TypePtr &data_type = nullptr); 206 207 /// \brief Create 0 dimension tensor from an int16_t scalar. 208 /// 209 /// \param[in] input [int16] the data for tensor. 210 /// \param[in] data_type [TypeId] data type. 211 explicit Tensor(int16_t input, const TypePtr &data_type = nullptr); 212 213 /// \brief Create 0 dimension tensor from an int8_t scalar. 214 /// 215 /// \param[in] input [int8] the data for tensor. 216 /// \param[in] data_type [TypeId] data type. 217 explicit Tensor(int8_t input, const TypePtr &data_type = nullptr); 218 219 /// \brief Create 0 dimension tensor from a double scalar. 220 /// 221 /// \param[in] input [double] the data for tensor. 222 /// \param[in] data_type [TypeId] data type. 223 explicit Tensor(double input, const TypePtr &data_type = nullptr); 224 225 /// \brief Create 0 dimension tensor from a float scalar. 226 /// 227 /// \param[in] input [float] the data for tensor. 228 /// \param[in] data_type [TypeId] data type. 229 explicit Tensor(float input, const TypePtr &data_type = nullptr); 230 231 /// \brief Create 0 dimension tensor from a float16 scalar. 232 /// 233 /// \param[in] input [float16] the data for tensor. 234 /// \param[in] data_type [TypeId] data type. 235 explicit Tensor(float16 input, const TypePtr &data_type = nullptr); 236 237 /// \brief Create 0 dimension tensor from a bfloat16 scalar. 238 /// 239 /// \param[in] input [bfloat16] the data for tensor. 240 /// \param[in] data_type [TypeId] data type. 241 explicit Tensor(bfloat16 input, const TypePtr &data_type = nullptr); 242 243 /// \brief Create 0 dimension tensor from a uint64 scalar. 244 /// 245 /// \param[in] input [uint64] the data for tensor. 246 /// \param[in] data_type [TypeId] data type. 247 explicit Tensor(uint64_t input, const TypePtr &data_type = nullptr); 248 249 /// \brief Create 0 dimension tensor from a uint32 scalar. 250 /// 251 /// \param[in] input [uint32] the data for tensor. 252 /// \param[in] data_type [TypeId] data type. 253 explicit Tensor(uint32_t input, const TypePtr &data_type = nullptr); 254 255 /// \brief Create 0 dimension tensor from a uint16 scalar. 256 /// 257 /// \param[in] input [uint16] the data for tensor. 258 /// \param[in] data_type [TypeId] data type. 259 explicit Tensor(uint16_t input, const TypePtr &data_type = nullptr); 260 261 /// \brief Create 0 dimension tensor from a uint8 scalar. 262 /// 263 /// \param[in] input [uint8] the data for tensor. 264 /// \param[in] data_type [TypeId] data type. 265 explicit Tensor(uint8_t input, const TypePtr &data_type = nullptr); 266 267 /// \brief Create 0 dimension tensor from a bool scalar. 268 /// 269 /// \param[in] input [bool] the data for tensor. 270 /// \param[in] data_type [TypeId] data type. 271 explicit Tensor(bool input, const TypePtr &data_type = nullptr); 272 273 /// \brief Create a chunk tensor with the given data size. 274 /// 275 /// \param[in] data_type [TypeId] Data type of the tensor. 276 /// \param[in] data_size The tensor chunk data size in number of elements. 277 Tensor(TypeId data_type, size_t data_size); 278 279 /// \brief Create a Tensor which shape and size may be inconsistent, such as Tensor with compression data. 280 /// 281 /// \param[in] origin_data_type [TypeId] Data type of the origin tensor. 282 /// \param[in] shape The shape represented by ShapeVector of the tensor. 283 /// \param[in] compression_data_size The compression data buffer size. 284 /// \param[in] TensorCompressionType The tensor compression type. 285 Tensor(TypeId origin_data_type, const ShapeVector &shape, size_t compression_data_size, 286 TensorCompressionType compression_type); 287 288 Tensor &operator=(const Tensor &tensor); 289 290 /// Destructor of Tensor. 291 ~Tensor() override; 292 293 MS_DECLARE_PARENT(Tensor, BaseTensor); 294 295 /// \brief Compare two tensor objects to see if they have same data type, shape and data address. 296 /// 297 /// \param[in] tensor The Tensor object to be compared. 298 /// \return True if having same type, shape and data address, otherwise false. 299 bool operator==(const Tensor &tensor) const; 300 301 /// \brief Create Abstract for Tensor. 302 /// 303 /// \return Abstract of Tensor. 304 abstract::AbstractBasePtr ToAbstract() override; 305 306 /// \brief Assign value to this tensor. 307 /// 308 /// \param[in] tensor The input tensor. 309 /// \return Tensor with new value. 310 Tensor &AssignValue(const Tensor &tensor); 311 312 bool operator==(const Value &other) const override { 313 if (other.isa<Tensor>()) { 314 auto &other_ = static_cast<const Tensor &>(other); 315 return *this == other_; 316 } 317 return false; 318 } 319 320 /// \brief To synchronize data with the device, you need to wait for the data to be valid. 321 /// 322 void data_sync(bool need_wait = true) const; 323 324 /// \brief To synchronize data with the device without keeping device address, you need to wait for the data to be 325 /// valid. 326 /// 327 void data_sync_directly(const DeviceSync *const device_sync, bool need_wait = true) const; 328 329 /// \brief Check if this Tensor is initialized. 330 /// 331 /// \return Whether this Tensor is initialized. is_init()332 bool is_init() const { return init_flag_; } 333 334 /// \brief Set the initialization flag of this Tensor. 335 /// 336 /// \param[in] flag Whether this Tensor is initialized. set_init_flag(bool flag)337 void set_init_flag(bool flag) { init_flag_ = flag; } 338 339 /// \brief Check whether this Tensor needs to be converted. 340 /// 341 /// \return Whether this Tensor needs to be converted. is_adapter()342 bool is_adapter() const { return adapter_flag_; } 343 344 /// \brief Set the adapter flag of this Tensor. 345 /// 346 /// \param[in] flag Whether this Tensor needs to be converted. set_adapter_flag(bool flag)347 void set_adapter_flag(bool flag) { adapter_flag_ = flag; } 348 349 /// \brief Check whether to release device memory. 350 /// 351 /// \return Ture if need to release device memory, otherwise false. need_release_device_mem()352 bool need_release_device_mem() const { return need_release_device_mem_; } 353 354 /// \brief Set the flag to determine whether the device memory needs to be released. 355 /// 356 /// \param[in] release_device_mem If release_device_mem is ture, the device memory will to be released. set_need_release_device_mem(bool release_device_mem)357 void set_need_release_device_mem(bool release_device_mem) { need_release_device_mem_ = release_device_mem; } 358 359 /// \brief Get the cast dtype of this Tensor. 360 /// 361 /// \return The cast dtype of this Tensor. cast_dtype()362 TypePtr cast_dtype() { return cast_dtype_; } 363 364 /// \brief Set the cast dtype of this Tensor. 365 /// 366 /// \param[in] dtype The input cast dtype. 367 void set_cast_dtype(const TypePtr &dtype = nullptr) { cast_dtype_ = dtype; } 368 369 /// \brief Used cache_enable to update the tensor from the cache to the host. 370 /// 371 /// \return True if caching is enabled, otherwise false. cache_enable()372 bool cache_enable() const { return cache_enable_; } 373 374 /// \brief Set cache_enable. 375 /// 376 /// \param[in] cache_enable Whether to enable caching. 377 void set_cache_enable(bool cache_enable = true) { cache_enable_ = cache_enable; } 378 379 /// \brief Get the pointer of hashmap tensor. 380 /// 381 /// \return The pointer of hashmap tensor. hashmap_tensor_ptr()382 std::shared_ptr<Tensor> hashmap_tensor_ptr() const { return hashmap_tensor_ptr_; } 383 384 /// \brief Set the pointer of hashmap tensor. 385 /// 386 /// \param[in] hashmap_tensor_ptr The input pointer of hashmap tensor. 387 void set_hashmap_tensor_ptr(const std::shared_ptr<Tensor> &hashmap_tensor_ptr = nullptr) { 388 hashmap_tensor_ptr_ = hashmap_tensor_ptr; 389 } 390 391 /// \brief Get the pointer of cache tensor. 392 /// 393 /// \return The pointer of cache tensor. cache_tensor_ptr()394 std::shared_ptr<Tensor> cache_tensor_ptr() const { return cache_tensor_ptr_; } 395 396 /// \brief Set the pointer of cache tensor. 397 /// 398 /// \param[in] cache_tensor_ptr The input pointer of cache tensor. 399 void set_cache_tensor_ptr(const std::shared_ptr<Tensor> &cache_tensor_ptr = nullptr) { 400 cache_tensor_ptr_ = cache_tensor_ptr; 401 } 402 403 /// \brief Check if this Tensor is the output of graph. 404 /// 405 /// \return Whether this Tensor is the output of graph IsGraphOutput()406 bool IsGraphOutput() const { return graph_output_; } 407 408 /// \brief Set whether this Tensor is the output of graph. SetIsGraphOutput()409 void SetIsGraphOutput() { graph_output_ = true; } 410 411 /// \brief Get whether this Tensor is updated by the device. 412 /// 413 /// \return Whether this Tensor is updated by the device. IsUpdatedByDevice()414 bool IsUpdatedByDevice() const { return updated_by_device_; } 415 416 /// \brief Set whether this Tensor is updated by the device. SetIsUpdateByDevice()417 void SetIsUpdateByDevice() { updated_by_device_ = true; } 418 419 /// \brief Get callback need to execute when value is updated of Tensor. 420 /// 421 /// \return The callback need to execute when value is updated of Tensor. update_value_callback()422 const std::function<void(const Tensor *)> &update_value_callback() const { return update_value_callback_; } 423 424 /// \brief Set callback need to execute when value is updated of Tensor. 425 /// 426 /// \param[in] update_value_callback The callback need to execute when value is updated of Tensor. set_update_value_callback(const std::function<void (const Tensor *)> & update_value_callback)427 void set_update_value_callback(const std::function<void(const Tensor *)> &update_value_callback) { 428 update_value_callback_ = update_value_callback; 429 } 430 431 /// \brief Get the memory chunk pointer and offset if memory chunk for this tensor exists. 432 /// 433 /// \return The memory chunk pointer and offset, nullptr and 0 if no memory chunk exists. 434 std::pair<void *, size_t> GetChunkOffset() const; 435 436 /// \brief Reset tensors data so that they are using contiguous memory chunks grouped by data type. 437 /// 438 /// \param[in] tensors The tensors to be processed. 439 /// \param[in] fusion_size Maximum memory chunk size in bytes, 0 for unlimited. 440 /// 441 /// \return Tensors that data are pointed to each contiguous memory chunks. 442 static TensorPtrList FlattenTensors(const TensorPtrList &tensors, size_t fusion_size = 0); 443 444 /// \brief Check if FlattenTensors called for the input tensors. 445 /// 446 /// \param[in] tensors The tensors to be checked. 447 /// 448 /// \return True if FlattenTensors called for input tensors, false otherwise. 449 static bool IsFlattened(const TensorPtrList &tensors); 450 451 /// \brief Get tensors for each contiguous memory chunks used by the input tensors. 452 /// 453 /// \param[in] tensors The input tensors. 454 /// 455 /// \return Tensors that data are pointed to each contiguous memory chunks, empty if failed. 456 static TensorPtrList GetFlattenedTensors(const TensorPtrList &tensors); 457 458 /// \brief Get tensors stub flag. 459 /// 460 /// \param[in] none. 461 /// 462 /// \return If compile with backend, return false, else return true. 463 static bool CheckStub(); 464 465 /// \brief Get the fusion size for the given flat tensors. 466 /// 467 /// \param[in] flat_tensors The input flat tensors. 468 /// 469 /// \return fusion size for the given flat tensors. 470 static size_t GetFusionSize(const TensorPtrList &flat_tensors); 471 472 /// \brief Get the tensor compression type. 473 /// 474 /// \return tensor compression type. compression_type()475 TensorCompressionType compression_type() const { return compression_type_; } 476 477 /// \brief If tensor use persistent tensor data. 478 /// 479 /// \return if use persistent tenor data. 480 bool is_persistent_data() const; 481 482 /// \brief Set tensor name. 483 /// 484 /// \param[in] tensor_name The tensor name. set_name(const std::string & tensor_name)485 void set_name(const std::string &tensor_name) { tensor_name_ = tensor_name; } 486 487 /// \brief Get the tensor name. 488 /// 489 /// \return tensor name. name()490 const std::string &name() const { return tensor_name_; } 491 492 /// \brief Set tensor quant param. 493 /// 494 /// \param[in] quant_param The tensor quant param. set_quant_param(const std::vector<std::shared_ptr<QuantizationParam>> & quant_params)495 void set_quant_param(const std::vector<std::shared_ptr<QuantizationParam>> &quant_params) { 496 quant_params_.assign(quant_params.begin(), quant_params.end()); 497 } 498 499 /// \brief Get the tensor quant param. 500 /// 501 /// \return tensor quant param. quant_params()502 const std::vector<std::shared_ptr<QuantizationParam>> &quant_params() const { return quant_params_; } 503 504 /// \brief Offload tensor to file. 505 /// 506 /// \return offload tensor success. 507 bool Offload(const std::string &file_path); 508 509 /// \brief Get tensor offload file path. 510 /// 511 /// \return offload file path, or empty string if tensor has not offload. 512 const std::string GetOffloadFilePath() const; 513 514 /// \brief pin tensor memory. 515 /// 516 /// \param[in] register to pin tensor data. 517 void PinMemory(PinnedMemRegister *pin_mem_register); 518 519 /// \brief unpin tensor memory. 520 void UnPinMemory(); 521 522 /// \brief Get tensor's device info. 523 /// 524 /// \return The device info of this tensor. device_info()525 DeviceInfo device_info() const { return device_info_; } 526 527 /// \brief Set tensor's device info. 528 /// 529 /// \param[in] device_info The tensor's device info. set_device_info(const DeviceInfo & device_info)530 void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; } 531 532 /// \brief Set tensor's device info. 533 /// 534 /// \param[in] format The input format. 535 /// \param[in] data_type The input data type. 536 /// \param[in] host_format The input host format. 537 void SetDeviceInfo(const std::string &format, const TypePtr &data_type, 538 const std::string &host_format = "DefaultFormat"); 539 set_copy_done_flag(bool flag)540 void set_copy_done_flag(bool flag) { copy_done_flag_ = flag; } get_copy_done_flag()541 bool get_copy_done_flag() const { return copy_done_flag_; } 542 bool copy_done_flag_{false}; 543 544 private: 545 // Really execute callback function when host value is updated of Tensor. 546 void ExecuteUpdateValueCallback() const; 547 548 bool init_flag_{false}; 549 bool adapter_flag_{false}; 550 bool graph_output_{false}; 551 bool updated_by_device_{false}; 552 // Release device address of graph output tensor or not. 553 bool need_release_device_mem_{false}; 554 bool cache_enable_{false}; 555 std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr}; 556 std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr}; 557 TypePtr cast_dtype_{nullptr}; 558 std::function<void(const Tensor *)> update_value_callback_{nullptr}; 559 PinnedMemRegister *pin_mem_register_{nullptr}; 560 TensorCompressionType compression_type_{kNoCompression}; 561 std::vector<std::shared_ptr<QuantizationParam>> quant_params_; 562 std::string tensor_name_; 563 // brief Device info of Tensor 564 // 565 // Includes the format and data type of a tensor on device. 566 DeviceInfo device_info_; 567 }; 568 569 // CSRTensor entity class 570 class MS_CORE_API CSRTensor : public MetaSparseTensor { 571 public: 572 abstract::AbstractBasePtr ToAbstract() override; 573 574 /// \brief Create CSRTensor with given data type from another tensor. 575 /// 576 /// \param[in] indptr [Tensor] The indices pointer. 577 /// \param[in] indices [Tensor] The indices. 578 /// \param[in] values [Tensor] The values. 579 /// \param[in] shape The shape represented by ShapeVector of the CSRensor. 580 CSRTensor(const TensorPtr indptr, const TensorPtr indices, const TensorPtr values, const ShapeVector &shape); 581 582 /// Destructor of CSRTensor. 583 ~CSRTensor() override = default; 584 MS_DECLARE_PARENT(CSRTensor,MetaSparseTensor)585 MS_DECLARE_PARENT(CSRTensor, MetaSparseTensor) 586 587 /// \brief Gets CSRTensor's indptr. 588 /// 589 /// \return [TensorPtr] The indices pointer. 590 TensorPtr GetIndptr() { return indptr_; } 591 592 /// \brief Gets CSRTensor's indices. 593 /// 594 /// \return [TensorPtr] The indices. GetIndices()595 TensorPtr GetIndices() { return indices_; } 596 597 /// \brief Gets CSRTensor's values. 598 /// 599 /// \return [TensorPtr] The values. GetValues()600 TensorPtr GetValues() { return values_; } 601 602 /// \brief Compare two csrtensor objects to see if they have same data address. 603 /// 604 /// \param[in] csr_tensor The csrtensor object to be compared. 605 /// \return True if having same data address, otherwise false. 606 bool operator==(const CSRTensor &csr_tensor) const { return &csr_tensor == this; } 607 608 bool operator==(const Value &other) const override { 609 if (other.isa<CSRTensor>()) { 610 auto &other_ = static_cast<const CSRTensor &>(other); 611 return *this == other_; 612 } 613 return false; 614 } 615 616 const size_t GetSizeAt(size_t index) const; 617 618 TensorPtr GetTensorAt(size_t index) const; 619 GetTensorLength()620 const size_t GetTensorLength() const { return kShapeIdx + shape().size(); } 621 622 /// \brief Get display information of this Tensor. 623 /// 624 /// \return The display information of this Tensor. 625 std::string ToString() const override; 626 627 static constexpr size_t kIndptrIdx = 0; 628 static constexpr size_t kIndicesIdx = 1; 629 static constexpr size_t kValuesIdx = 2; 630 static constexpr size_t kShapeIdx = 3; 631 632 private: 633 TensorPtr indptr_; 634 TensorPtr indices_; 635 TensorPtr values_; 636 }; 637 using CSRTensorPtr = std::shared_ptr<CSRTensor>; 638 639 // COOTensor entity class 640 class MS_CORE_API COOTensor : public MetaSparseTensor { 641 public: 642 abstract::AbstractBasePtr ToAbstract() override; 643 644 /// \brief Create COOTensor with given data type from another tensor. 645 /// 646 /// \param[in] indices [Tensor] The indices. 647 /// \param[in] values [Tensor] The values. 648 /// \param[in] shape The shape represented by ShapeVector of the COOTensor. COOTensor(const TensorPtr indices,const TensorPtr values,const ShapeVector & shape)649 COOTensor(const TensorPtr indices, const TensorPtr values, const ShapeVector &shape) 650 : MetaSparseTensor(values->data_type(), shape), indices_(indices), values_(values) {} 651 652 /// Destructor of COOTensor. 653 ~COOTensor() override = default; 654 MS_DECLARE_PARENT(COOTensor,MetaSparseTensor)655 MS_DECLARE_PARENT(COOTensor, MetaSparseTensor) 656 657 /// \brief Gets COOTensor's indices. 658 /// 659 /// \return [TensorPtr] The indices. 660 TensorPtr GetIndices() { return indices_; } 661 662 /// \brief Gets COOTensor's values. 663 /// 664 /// \return [TensorPtr] The values. GetValues()665 TensorPtr GetValues() { return values_; } 666 667 TensorPtr GetTensorAt(size_t index) const; 668 GetTensorLength()669 const size_t GetTensorLength() const { return kShapeIdx + shape().size(); } 670 671 /// \brief Compare two cootensor objects to see if they have same address. 672 /// 673 /// \param[in] coo_tensor The cootensor object to be compared. 674 /// \return True if having same data address, otherwise false. 675 bool operator==(const COOTensor &coo_tensor) const { return &coo_tensor == this; } 676 677 bool operator==(const Value &other) const override { 678 if (other.isa<COOTensor>()) { 679 auto &other_ = static_cast<const COOTensor &>(other); 680 return *this == other_; 681 } 682 return false; 683 } 684 685 /// \brief Get display information of this Tensor. 686 /// 687 /// \return The display information of this Tensor. 688 std::string ToString() const override; 689 690 static constexpr size_t kIndicesIdx = 0; 691 static constexpr size_t kValuesIdx = 1; 692 static constexpr size_t kShapeIdx = 2; 693 694 private: 695 TensorPtr indices_; 696 TensorPtr values_; 697 }; 698 using COOTensorPtr = std::shared_ptr<COOTensor>; 699 700 // RowTensor entity class 701 class MS_CORE_API RowTensor : public MetaSparseTensor { 702 public: 703 abstract::AbstractBasePtr ToAbstract() override; 704 705 /// \brief Create RowTensor with given data type from another tensor. 706 /// 707 /// \param[in] indices [Tensor] The indices. 708 /// \param[in] values [Tensor] The values. 709 /// \param[in] shape The shape represented by ShapeVector of the RowTensor. RowTensor(const TensorPtr indices,const TensorPtr values,const ShapeVector & shape)710 RowTensor(const TensorPtr indices, const TensorPtr values, const ShapeVector &shape) 711 : MetaSparseTensor(values->data_type(), shape), indices_(indices), values_(values) {} 712 713 /// Destructor of RowTensor. 714 ~RowTensor() override = default; 715 716 /// \brief Gets RowTensor's indices. 717 /// 718 /// \return [TensorPtr] The indices. GetIndices()719 TensorPtr GetIndices() { return indices_; } 720 721 /// \brief Gets RowTensor's values. 722 /// 723 /// \return [TensorPtr] The values. GetValues()724 TensorPtr GetValues() { return values_; } 725 726 /// \brief Compare two rowtensor objects to see if they have same address. 727 /// 728 /// \param[in] coo_tensor The rowtensor object to be compared. 729 /// \return True if having same data address, otherwise false. 730 bool operator==(const RowTensor &row_tensor) const { return &row_tensor == this; } 731 732 bool operator==(const Value &other) const override { 733 if (other.isa<RowTensor>()) { 734 auto &other_ = static_cast<const RowTensor &>(other); 735 return *this == other_; 736 } 737 return false; 738 } 739 740 /// \brief Get display information of this Tensor. 741 /// 742 /// \return The display information of this Tensor. 743 std::string ToString() const override; 744 745 private: 746 TensorPtr indices_; 747 TensorPtr values_; 748 }; 749 using RowTensorPtr = std::shared_ptr<RowTensor>; 750 } // namespace tensor 751 } // namespace mindspore 752 753 #endif // MINDSPORE_CORE_IR_TENSOR_H_ 754