• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_MAP_TENSOR_H_
18 #define MINDSPORE_CORE_IR_MAP_TENSOR_H_
19 
20 #include <tuple>
21 #include <memory>
22 #include <vector>
23 #include <string>
24 #include <utility>
25 #include "ir/anf.h"
26 #include "ir/dtype.h"
27 #include "ir/tensor.h"
28 #include "ir/param_info.h"
29 #include "ir/scalar.h"
30 #include "mindapi/base/macros.h"
31 #include "utils/shape_utils.h"
32 #include "include/common/utils/utils.h"
33 
34 namespace mindspore {
35 namespace tensor {
// Forward declaration: lets the shared-pointer alias below be declared before the class body.
class MapTensor;

/// \brief Shared-ownership smart pointer to a MapTensor.
using MapTensorPtr = std::shared_ptr<MapTensor>;
39 ///
40 /// \brief MapTensor is a dynamic tensor with map like index functions.
41 ///
42 class MS_CORE_API MapTensor final : public Tensor {
43  public:
44   struct ExportData {
45     TensorPtr key_tensor;
46     TensorPtr value_tensor;
47     TensorPtr status_tensor;
48   };
49 
50   enum class Status {
51     kUnchanged = 0,
52     kModified = 1,
53     kErased = 2,
54   };
55 
56   MapTensor() = default;
57 
58   /// \brief Create a empty MapTensor.
59   ///
60   /// \param[in] key_dtype [TypeId] The key data type id.
61   /// \param[in] value_dtype [TypeId] The value data type id.
62   /// \param[in] value_shape [TypeId] The value shape.
63   /// \param[in] default_value [ValuePtr] The default value.
64   /// \param[in] permit_filter_value [ValuePtr] The permit filter value.
65   /// \param[in] evict_filter_value [ValuePtr] The evict filter value.
66   MapTensor(TypeId key_dtype, TypeId value_dtype, const ShapeVector &value_shape, const ValuePtr &default_value,
67             const ValuePtr &permit_filter_value = nullptr, const ValuePtr &evict_filter_value = nullptr)
key_dtype_(key_dtype)68       : key_dtype_(key_dtype), default_value_(default_value) {
69     data_type_ = value_dtype;
70     value_shape_ = value_shape;
71     key_shape_ = {abstract::Shape::kShapeDimAny};
72     shape_ = {abstract::Shape::kShapeDimAny};
73     (void)shape_.insert(shape_.cend(), value_shape.cbegin(), value_shape.cend());
74     size_ = shape_[0];
75     ShapeVector key_shape = {abstract::Shape::kShapeDimAny};
76     key_tensor_ = std::make_shared<Tensor>(key_dtype, key_shape);
77     value_tensor_ = std::make_shared<Tensor>(value_dtype, shape_);
78     status_tensor_ = std::make_shared<Tensor>(kNumberTypeInt, key_shape);
79     permit_filter_value_ = (permit_filter_value == nullptr) ? std::make_shared<Int64Imm>(1) : permit_filter_value;
80     evict_filter_value_ = (evict_filter_value == nullptr) ? std::make_shared<Int64Imm>(INT64_MAX) : evict_filter_value;
81   }
82 
83   /// \brief Create a new MapTensor.
84   ///
85   /// \param[in] key_tensor [Tensor] The key tensor.
86   /// \param[in] value_tensor [Tensor] The value tensor.
87   /// \param[in] status_tensor [Tensor] The status tensor.
88   /// \param[in] default_value [ValuePtr] The default value.
89   /// \param[in] permit_filter_value [ValuePtr] The permit filter value.
90   /// \param[in] evict_filter_value [ValuePtr] The evict filter value.
91   MapTensor(const TensorPtr &key_tensor, const TensorPtr &value_tensor, const TensorPtr &status_tensor,
92             const ValuePtr &default_value, const ValuePtr &permit_filter_value = nullptr,
93             const ValuePtr &evict_filter_value = nullptr)
94       : key_dtype_(key_tensor->data_type()), default_value_(default_value) {
95     data_type_ = value_tensor->data_type();
96     shape_ = value_tensor->shape();
97     key_shape_ = key_tensor->shape();
98     value_shape_.clear();
99     (void)value_shape_.insert(value_shape_.cend(), shape_.cbegin() + 1, shape_.cend());
100     size_ = shape_.size() != 0 ? shape_[0] : (abstract::Shape::kShapeDimAny);
101     key_tensor_ = key_tensor;
102     value_tensor_ = value_tensor;
103     status_tensor_ = status_tensor;
104     permit_filter_value_ = (permit_filter_value == nullptr) ? std::make_shared<Int64Imm>(1) : permit_filter_value;
105     evict_filter_value_ = (evict_filter_value == nullptr) ? std::make_shared<Int64Imm>(INT64_MAX) : evict_filter_value;
106   }
107 
108   ~MapTensor() override = default;
109 
110   MS_DECLARE_PARENT(MapTensor, Tensor)
111 
112   std::size_t hash() const override;
113 
114   bool operator==(const Value &other) const override {
115     if (this == &other) {
116       return true;
117     }
118     if (!other.isa<MapTensor>()) {
119       return false;
120     }
121     auto &other_ = static_cast<const MapTensor &>(other);
122     return *this == other_;
123   }
124 
125   bool operator==(const MapTensor &other) const;
126 
key_dtype()127   TypeId key_dtype() const { return key_dtype_; }
128 
value_dtype()129   TypeId value_dtype() const { return data_type_; }
130 
size()131   int64_t size() const { return size_; }
132 
value_shape()133   const ShapeVector &value_shape() const { return value_shape_; }
134 
default_value()135   const ValuePtr &default_value() const { return default_value_; }
136 
permit_filter_value()137   const ValuePtr &permit_filter_value() const { return permit_filter_value_; }
138 
evict_filter_value()139   const ValuePtr &evict_filter_value() const { return evict_filter_value_; }
140 
KeyDtype()141   TypePtr KeyDtype() const { return TypeIdToType(key_dtype_); }
142 
ValueDtype()143   TypePtr ValueDtype() const { return TypeIdToType(data_type_); }
144 
145   abstract::AbstractBasePtr ToAbstract() override;
146 
147   std::string ToString() const override;
148 
149   /// \brief Update MapTensor from exported data.
150   ///
151   /// \param[in] data [ExportData] The data.
152   void Update(const ExportData &data);
153 
154   /// \brief Exported MapTensor data.
155   ///
156   /// \param[in] incremental [bool] False for incremental export, true for full export.
157   /// \return The exported data.
158   ExportData Export(bool incremental = false) const;
159 
160   /// \brief Exported slice data from MapTensor.
161   ///
162   /// \param[in] incremental [bool] False for incremental export, true for full export.
163   /// \param[out] last_slice [bool *] Point a bool variable which indicates whether the slice by export is the last
164   /// slice, that is, the export is complete and all slices are exported.
165   /// \return The exported data.
166   ExportData ExportSlice(bool incremental, bool *last_slice) const;
167 
168   /// \brief Exported MapTensor data from device.
169   ///
170   /// \param[in] device_sync [DeviceSyncPtr] The device resource synchronizer(such as DeviceAddress).
171   /// \param[in] incremental [bool] True for incremental export, false for full export.
172   /// \param[out] last_slice [bool *] Point a bool variable which indicates whether the slice by export is the last
173   /// slice, that is, the export is complete and all slices are exported. nullptr indicates that slice export is
174   /// disabled.
175   /// \return The exported data.
176   ExportData ExportDataFromDevice(const DeviceSyncPtr &device_sync, bool incremental, bool *last_slice = nullptr) const;
177 
178   /// \brief Get three tensor length from device data with tensor shape and type.
179   ///
180   /// \param[in] export_data [HashTableExportData] The export data buffer from device side.
181   void TransExportDataToTensor(const HashTableExportData &export_data) const;
182 
183   /// \brief Get the key tensor of MapTensor data.
184   ///
185   /// \return The key tensor.
key_tensor()186   const TensorPtr &key_tensor() const { return key_tensor_; }
187 
188   /// \brief Get the value tensor of MapTensor data.
189   ///
190   /// \return The value tensor.
value_tensor()191   const TensorPtr &value_tensor() const { return value_tensor_; }
192 
193   /// \brief Get the status tensor of MapTensor data.
194   ///
195   /// \return The status tensor.
status_tensor()196   const TensorPtr &status_tensor() const { return status_tensor_; }
197 
set_key_tensor(const TensorPtr key_tensor)198   void set_key_tensor(const TensorPtr key_tensor) { key_tensor_ = key_tensor; }
199 
set_value_tensor(const TensorPtr value_tensor)200   void set_value_tensor(const TensorPtr value_tensor) { value_tensor_ = value_tensor; }
201 
set_status_tensor(const TensorPtr status_tensor)202   void set_status_tensor(const TensorPtr status_tensor) { status_tensor_ = status_tensor; }
203 
204   bool CheckData() const;
205 
206  private:
207   // Data type of the keys.
208   TypeId key_dtype_;
209 
210   // The shape of keys.
211   ShapeVector key_shape_;
212 
213   // Default value. should be a scalar as the initial value or a string as the initializer name.
214   ValuePtr default_value_;
215 
216   // Permission threshold: When an element is accessed more than the threshold, it will be actually inserted into map.
217   ValuePtr permit_filter_value_;
218 
219   //  If the elements in the map are not used or updated within the time interval indicated by the threshold,
220   //  these elements will be removed from the map.
221   ValuePtr evict_filter_value_;
222 
223   // The shape of values
224   ShapeVector value_shape_;
225 
226   // The size of keys, shape_ is (size_, value_shape_).
227   int64_t size_;
228 
229   // Key tensor of data.
230   TensorPtr key_tensor_;
231 
232   // Value tensor of data.
233   TensorPtr value_tensor_;
234 
235   // Status tensor of data.
236   TensorPtr status_tensor_;
237 };
238 }  // namespace tensor
239 }  // namespace mindspore
240 #endif  // MINDSPORE_CORE_IR_MAP_TENSOR_H_
241