1 /** 2 * Copyright 2022-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_MAP_TENSOR_H_ 18 #define MINDSPORE_CORE_IR_MAP_TENSOR_H_ 19 20 #include <tuple> 21 #include <memory> 22 #include <vector> 23 #include <string> 24 #include <utility> 25 #include "ir/anf.h" 26 #include "ir/dtype.h" 27 #include "ir/tensor.h" 28 #include "ir/param_info.h" 29 #include "ir/scalar.h" 30 #include "mindapi/base/macros.h" 31 #include "utils/shape_utils.h" 32 #include "include/common/utils/utils.h" 33 34 namespace mindspore { 35 namespace tensor { 36 class MapTensor; 37 // Smart pointer for MapTensor. 38 using MapTensorPtr = std::shared_ptr<MapTensor>; 39 /// 40 /// \brief MapTensor is a dynamic tensor with map like index functions. 41 /// 42 class MS_CORE_API MapTensor final : public Tensor { 43 public: 44 struct ExportData { 45 TensorPtr key_tensor; 46 TensorPtr value_tensor; 47 TensorPtr status_tensor; 48 }; 49 50 enum class Status { 51 kUnchanged = 0, 52 kModified = 1, 53 kErased = 2, 54 }; 55 56 MapTensor() = default; 57 58 /// \brief Create a empty MapTensor. 59 /// 60 /// \param[in] key_dtype [TypeId] The key data type id. 61 /// \param[in] value_dtype [TypeId] The value data type id. 62 /// \param[in] value_shape [TypeId] The value shape. 63 /// \param[in] default_value [ValuePtr] The default value. 64 /// \param[in] permit_filter_value [ValuePtr] The permit filter value. 65 /// \param[in] evict_filter_value [ValuePtr] The evict filter value. 66 MapTensor(TypeId key_dtype, TypeId value_dtype, const ShapeVector &value_shape, const ValuePtr &default_value, 67 const ValuePtr &permit_filter_value = nullptr, const ValuePtr &evict_filter_value = nullptr) key_dtype_(key_dtype)68 : key_dtype_(key_dtype), default_value_(default_value) { 69 data_type_ = value_dtype; 70 value_shape_ = value_shape; 71 key_shape_ = {abstract::Shape::kShapeDimAny}; 72 shape_ = {abstract::Shape::kShapeDimAny}; 73 (void)shape_.insert(shape_.cend(), value_shape.cbegin(), value_shape.cend()); 74 size_ = shape_[0]; 75 ShapeVector key_shape = {abstract::Shape::kShapeDimAny}; 76 key_tensor_ = std::make_shared<Tensor>(key_dtype, key_shape); 77 value_tensor_ = std::make_shared<Tensor>(value_dtype, shape_); 78 status_tensor_ = std::make_shared<Tensor>(kNumberTypeInt, key_shape); 79 permit_filter_value_ = (permit_filter_value == nullptr) ? std::make_shared<Int64Imm>(1) : permit_filter_value; 80 evict_filter_value_ = (evict_filter_value == nullptr) ? std::make_shared<Int64Imm>(INT64_MAX) : evict_filter_value; 81 } 82 83 /// \brief Create a new MapTensor. 84 /// 85 /// \param[in] key_tensor [Tensor] The key tensor. 86 /// \param[in] value_tensor [Tensor] The value tensor. 87 /// \param[in] status_tensor [Tensor] The status tensor. 88 /// \param[in] default_value [ValuePtr] The default value. 89 /// \param[in] permit_filter_value [ValuePtr] The permit filter value. 90 /// \param[in] evict_filter_value [ValuePtr] The evict filter value. 91 MapTensor(const TensorPtr &key_tensor, const TensorPtr &value_tensor, const TensorPtr &status_tensor, 92 const ValuePtr &default_value, const ValuePtr &permit_filter_value = nullptr, 93 const ValuePtr &evict_filter_value = nullptr) 94 : key_dtype_(key_tensor->data_type()), default_value_(default_value) { 95 data_type_ = value_tensor->data_type(); 96 shape_ = value_tensor->shape(); 97 key_shape_ = key_tensor->shape(); 98 value_shape_.clear(); 99 (void)value_shape_.insert(value_shape_.cend(), shape_.cbegin() + 1, shape_.cend()); 100 size_ = shape_.size() != 0 ? shape_[0] : (abstract::Shape::kShapeDimAny); 101 key_tensor_ = key_tensor; 102 value_tensor_ = value_tensor; 103 status_tensor_ = status_tensor; 104 permit_filter_value_ = (permit_filter_value == nullptr) ? std::make_shared<Int64Imm>(1) : permit_filter_value; 105 evict_filter_value_ = (evict_filter_value == nullptr) ? std::make_shared<Int64Imm>(INT64_MAX) : evict_filter_value; 106 } 107 108 ~MapTensor() override = default; 109 110 MS_DECLARE_PARENT(MapTensor, Tensor) 111 112 std::size_t hash() const override; 113 114 bool operator==(const Value &other) const override { 115 if (this == &other) { 116 return true; 117 } 118 if (!other.isa<MapTensor>()) { 119 return false; 120 } 121 auto &other_ = static_cast<const MapTensor &>(other); 122 return *this == other_; 123 } 124 125 bool operator==(const MapTensor &other) const; 126 key_dtype()127 TypeId key_dtype() const { return key_dtype_; } 128 value_dtype()129 TypeId value_dtype() const { return data_type_; } 130 size()131 int64_t size() const { return size_; } 132 value_shape()133 const ShapeVector &value_shape() const { return value_shape_; } 134 default_value()135 const ValuePtr &default_value() const { return default_value_; } 136 permit_filter_value()137 const ValuePtr &permit_filter_value() const { return permit_filter_value_; } 138 evict_filter_value()139 const ValuePtr &evict_filter_value() const { return evict_filter_value_; } 140 KeyDtype()141 TypePtr KeyDtype() const { return TypeIdToType(key_dtype_); } 142 ValueDtype()143 TypePtr ValueDtype() const { return TypeIdToType(data_type_); } 144 145 abstract::AbstractBasePtr ToAbstract() override; 146 147 std::string ToString() const override; 148 149 /// \brief Update MapTensor from exported data. 150 /// 151 /// \param[in] data [ExportData] The data. 152 void Update(const ExportData &data); 153 154 /// \brief Exported MapTensor data. 155 /// 156 /// \param[in] incremental [bool] False for incremental export, true for full export. 157 /// \return The exported data. 158 ExportData Export(bool incremental = false) const; 159 160 /// \brief Exported slice data from MapTensor. 161 /// 162 /// \param[in] incremental [bool] False for incremental export, true for full export. 163 /// \param[out] last_slice [bool *] Point a bool variable which indicates whether the slice by export is the last 164 /// slice, that is, the export is complete and all slices are exported. 165 /// \return The exported data. 166 ExportData ExportSlice(bool incremental, bool *last_slice) const; 167 168 /// \brief Exported MapTensor data from device. 169 /// 170 /// \param[in] device_sync [DeviceSyncPtr] The device resource synchronizer(such as DeviceAddress). 171 /// \param[in] incremental [bool] True for incremental export, false for full export. 172 /// \param[out] last_slice [bool *] Point a bool variable which indicates whether the slice by export is the last 173 /// slice, that is, the export is complete and all slices are exported. nullptr indicates that slice export is 174 /// disabled. 175 /// \return The exported data. 176 ExportData ExportDataFromDevice(const DeviceSyncPtr &device_sync, bool incremental, bool *last_slice = nullptr) const; 177 178 /// \brief Get three tensor length from device data with tensor shape and type. 179 /// 180 /// \param[in] export_data [HashTableExportData] The export data buffer from device side. 181 void TransExportDataToTensor(const HashTableExportData &export_data) const; 182 183 /// \brief Get the key tensor of MapTensor data. 184 /// 185 /// \return The key tensor. key_tensor()186 const TensorPtr &key_tensor() const { return key_tensor_; } 187 188 /// \brief Get the value tensor of MapTensor data. 189 /// 190 /// \return The value tensor. value_tensor()191 const TensorPtr &value_tensor() const { return value_tensor_; } 192 193 /// \brief Get the status tensor of MapTensor data. 194 /// 195 /// \return The status tensor. status_tensor()196 const TensorPtr &status_tensor() const { return status_tensor_; } 197 set_key_tensor(const TensorPtr key_tensor)198 void set_key_tensor(const TensorPtr key_tensor) { key_tensor_ = key_tensor; } 199 set_value_tensor(const TensorPtr value_tensor)200 void set_value_tensor(const TensorPtr value_tensor) { value_tensor_ = value_tensor; } 201 set_status_tensor(const TensorPtr status_tensor)202 void set_status_tensor(const TensorPtr status_tensor) { status_tensor_ = status_tensor; } 203 204 bool CheckData() const; 205 206 private: 207 // Data type of the keys. 208 TypeId key_dtype_; 209 210 // The shape of keys. 211 ShapeVector key_shape_; 212 213 // Default value. should be a scalar as the initial value or a string as the initializer name. 214 ValuePtr default_value_; 215 216 // Permission threshold: When an element is accessed more than the threshold, it will be actually inserted into map. 217 ValuePtr permit_filter_value_; 218 219 // If the elements in the map are not used or updated within the time interval indicated by the threshold, 220 // these elements will be removed from the map. 221 ValuePtr evict_filter_value_; 222 223 // The shape of values 224 ShapeVector value_shape_; 225 226 // The size of keys, shape_ is (size_, value_shape_). 227 int64_t size_; 228 229 // Key tensor of data. 230 TensorPtr key_tensor_; 231 232 // Value tensor of data. 233 TensorPtr value_tensor_; 234 235 // Status tensor of data. 236 TensorPtr status_tensor_; 237 }; 238 } // namespace tensor 239 } // namespace mindspore 240 #endif // MINDSPORE_CORE_IR_MAP_TENSOR_H_ 241