1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_TENSOR_DATA_H_ 18 #define MINDSPORE_CORE_IR_TENSOR_DATA_H_ 19 20 #include <memory> 21 #include <string> 22 #include "mindapi/base/macros.h" 23 #include "utils/os.h" 24 25 namespace mindspore::tensor { 26 // Tensor data interface. 27 class MS_CORE_API TensorData { 28 public: 29 /// \brief Virtual destructor is required for base classes. 30 virtual ~TensorData() = default; 31 32 /// \brief Get total number of elements. 33 /// 34 /// \return Total number of elements. 35 virtual ssize_t size() const = 0; 36 37 /// \brief Get byte size of a single element. 38 /// 39 /// \return Byte size of a single element. 40 virtual ssize_t itemsize() const = 0; 41 42 /// \brief Get total number of bytes. 43 /// 44 /// \return Total number of bytes. 45 virtual ssize_t nbytes() const = 0; 46 47 /// \brief Get number of dimensions. 48 /// 49 /// \return Number of dimensions. 50 virtual ssize_t ndim() const = 0; 51 52 /// \brief Get data pointer. 53 /// 54 /// \return Data pointer. 55 virtual void *data() = 0; 56 57 /// \brief Get const data pointer. 58 /// 59 /// \return Const data pointer. 60 virtual const void *const_data() const = 0; 61 62 /// \brief Get whether this tensor data is sub data. 63 /// 64 /// \return Whether this tensor data is sub data. 65 virtual bool is_sub_data() const = 0; 66 67 /// \brief Check whether this tensor data has sub data. 68 /// 69 /// \return True if this tensor data has sub data, otherwise false. 70 virtual bool has_sub_data() const = 0; 71 72 /// \brief Get whether this tensor data is from numpy. 73 /// 74 /// \return Whether this tensor data is from numpy. is_from_numpy()75 virtual bool is_from_numpy() const { return false; } 76 77 /// \brief Get whether this tensor data have use persistent storage to save data. 78 /// 79 /// \return Whether this tensor data have use persistent storage to save data. is_persistent_data()80 virtual bool is_persistent_data() const { return false; } 81 82 /// \brief Whether the data are equal. 83 /// 84 /// \param[in] other Another TensorData. 85 /// \return Ture if the two data are equal, otherwise false. equals(const TensorData & other)86 virtual bool equals(const TensorData &other) const { 87 if (this == &other) { 88 return true; 89 } 90 // By default, compare data byte by byte. 91 auto this_data = static_cast<const uint8_t *>(const_data()); 92 auto other_data = static_cast<const uint8_t *>(other.const_data()); 93 if (this_data == nullptr || other_data == nullptr) { 94 // null means data not initialized, compare uninitialized data always return false. 95 return false; 96 } 97 return (this_data == other_data) || (ndim() == other.ndim() && nbytes() == other.nbytes() && 98 std::equal(this_data, this_data + nbytes(), other_data)); 99 } 100 101 /// \brief Get display information about this TensorData. 102 /// 103 /// \param[in] type The type of tensor data. 104 /// \param[in] shape The shape of tensor data. 105 /// \param[in] use_comma Whether to use comma. 106 /// \return The display information. 107 virtual std::string ToString(TypeId type, const ShapeVector &shape, bool use_comma) const = 0; 108 109 /// \brief Set data saved file path. 110 /// 111 /// \param[in] data file path. 112 /// \return Void. set_file_path(const std::string & path)113 virtual void set_file_path(const std::string &path) { 114 MS_LOG(INFO) << "Call default set file path, and do nothing with " << path << "."; 115 } 116 117 /// \brief Get data saved file path. 118 /// 119 /// \return data file path. file_path()120 virtual const std::string file_path() const { return ""; } 121 }; 122 123 using TensorDataPtr = std::shared_ptr<TensorData>; 124 } // namespace mindspore::tensor 125 #endif // MINDSPORE_CORE_IR_TENSOR_DATA_H_ 126