1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ir/map_tensor.h"
18 #include <vector>
19 #include <algorithm>
20 #include "abstract/abstract_value.h"
21 #include "ir/tensor.h"
22 #include "utils/log_adapter.h"
23 #include "utils/ms_utils_secure.h"
24 #include "runtime/device/hash_table.h"
25
26 namespace mindspore {
27 using device::HashTable;
28 namespace tensor {
29 using tensor::Tensor;
30 using tensor::TensorPtr;
// Positions of the key/value/status tensors inside a HashTableExportData payload.
constexpr size_t kKeyTensorIndex = 0;
constexpr size_t kValueTensorIndex = 1;
constexpr size_t kStatusTensorIndex = 2;
// A valid export payload contains exactly these three tensors.
constexpr size_t kExportTensorNum = 3;
35
// Return the concatenation of shape `a` followed by shape `b`.
static ShapeVector ConcatShape(const ShapeVector &a, const ShapeVector &b) {
  ShapeVector concatenated;
  concatenated.reserve(a.size() + b.size());
  (void)concatenated.insert(concatenated.end(), a.cbegin(), a.cend());
  (void)concatenated.insert(concatenated.end(), b.cbegin(), b.cend());
  return concatenated;
}
41
hash() const42 std::size_t MapTensor::hash() const { return static_cast<std::size_t>(tid()); }
43
operator ==(const MapTensor & other) const44 bool MapTensor::operator==(const MapTensor &other) const { return this == &other; }
45
ToAbstract()46 abstract::AbstractBasePtr MapTensor::ToAbstract() {
47 if (param_info_ != nullptr) {
48 // For parameter, a broaden abstract is created with ref_key set.
49 ValuePtr ref_key = std::make_shared<RefKey>(param_info_->name());
50 return std::make_shared<abstract::AbstractMapTensor>(shared_from_base<MapTensor>(), ref_key);
51 } else {
52 // For value, an abstract is created with value set.
53 return std::make_shared<abstract::AbstractMapTensor>(shared_from_base<MapTensor>());
54 }
55 }
56
ToString() const57 std::string MapTensor::ToString() const {
58 auto key_dtype = KeyDtype();
59 auto value_dtype = ValueDtype();
60 return "MapTensor(key_dtype=" + (key_dtype == nullptr ? "<null>" : key_dtype->ToString()) +
61 ", value_dtype=" + (value_dtype == nullptr ? "<null>" : value_dtype->ToString()) +
62 ", value_shape=" + tensor::ShapeToString(value_shape()) +
63 ", default_value=" + (default_value_ == nullptr ? "<null>" : default_value_->ToString()) +
64 ", permit_filter=" + (permit_filter_value_ == nullptr ? "<null>" : permit_filter_value_->ToString()) +
65 ", evict_filter=" + (evict_filter_value_ == nullptr ? "<null>" : evict_filter_value_->ToString()) + ")";
66 }
67
// Replace the host-side key/value/status tensors with the given export data.
// All three tensors are validated up front so that a null input leaves this
// MapTensor completely unmodified (no partial update).
void MapTensor::Update(const MapTensor::ExportData &data) {
  MS_EXCEPTION_IF_NULL(data.key_tensor);
  MS_EXCEPTION_IF_NULL(data.value_tensor);
  MS_EXCEPTION_IF_NULL(data.status_tensor);
  key_tensor_ = data.key_tensor;
  value_tensor_ = data.value_tensor;
  status_tensor_ = data.status_tensor;
}
76
TransExportDataToTensor(const HashTableExportData & export_data) const77 void MapTensor::TransExportDataToTensor(const HashTableExportData &export_data) const {
78 if (export_data.size() != kExportTensorNum) {
79 MS_LOG(EXCEPTION) << "Invalid MapTensor export data.";
80 }
81
82 auto keys = export_data.at(kKeyTensorIndex);
83 auto values = export_data.at(kValueTensorIndex);
84 auto statuses = export_data.at(kStatusTensorIndex);
85 MS_EXCEPTION_IF_NULL(keys);
86 MS_EXCEPTION_IF_NULL(values);
87 MS_EXCEPTION_IF_NULL(statuses);
88
89 // The key tensor.
90 auto keys_length = keys->size();
91 auto keys_num = keys_length / abstract::TypeIdSize(key_dtype());
92 ShapeVector key_tensor_shape{SizeToLong(keys_num)};
93 auto tensor_key = key_tensor();
94 MS_EXCEPTION_IF_NULL(tensor_key);
95 (void)tensor_key->set_shape(key_tensor_shape);
96 if (keys_length > 0) {
97 auto ret = memcpy_s(tensor_key->data_c(), tensor_key->Size(), keys->data(), keys_length);
98 if (ret != EOK) {
99 MS_LOG(INTERNAL_EXCEPTION) << "Memcpy for key tensor failed, errno[" << ret << "]";
100 }
101 }
102
103 // The value tensor.
104 auto values_length = values->size();
105 auto element_length = LongToSize(abstract::ShapeSize(value_shape())) * abstract::TypeIdSize(value_dtype());
106 MS_EXCEPTION_IF_ZERO("element_length", element_length);
107 auto values_num = values_length / element_length;
108 ShapeVector value_tensor_shape{SizeToLong(values_num)};
109 (void)std::copy(value_shape().cbegin(), value_shape().cend(), std::back_inserter(value_tensor_shape));
110 auto tensor_value = value_tensor();
111 MS_EXCEPTION_IF_NULL(tensor_value);
112 (void)tensor_value->set_shape(value_tensor_shape);
113 if (values_length > 0) {
114 auto ret = memcpy_s(tensor_value->data_c(), tensor_value->Size(), values->data(), values_length);
115 if (ret != EOK) {
116 MS_LOG(INTERNAL_EXCEPTION) << "Memcpy for value tensor failed, errno[" << ret << "]";
117 }
118 }
119
120 // The status tensor
121 auto statuses_length = statuses->size();
122 auto statuses_num = statuses_length / abstract::TypeIdSize(kNumberTypeInt);
123 // The status tensor shape is same as the shape of key tensor.
124 if (statuses_num != keys_num) {
125 MS_LOG(INTERNAL_EXCEPTION) << "Invalid export data: keys num: " << keys_num << ", statuses num: " << statuses_num;
126 }
127 ShapeVector status_tensor_shape{SizeToLong(statuses_num)};
128 auto tensor_status = status_tensor();
129 MS_EXCEPTION_IF_NULL(tensor_status);
130 (void)tensor_status->set_shape(status_tensor_shape);
131 if (statuses_length > 0) {
132 auto ret = memcpy_s(tensor_status->data_c(), tensor_status->Size(), statuses->data(), statuses_length);
133 if (ret != EOK) {
134 MS_LOG(INTERNAL_EXCEPTION) << "Memcpy for status tensor failed, errno[" << ret << "]";
135 }
136 }
137 }
138
ExportDataFromDevice(const DeviceSyncPtr & device_sync,bool incremental,bool * last_slice) const139 MapTensor::ExportData MapTensor::ExportDataFromDevice(const DeviceSyncPtr &device_sync, bool incremental,
140 bool *last_slice) const {
141 auto user_data = device_sync->user_data();
142 MS_EXCEPTION_IF_NULL(user_data);
143 HashTableExportData export_data;
144 if (key_dtype() == TypeId::kNumberTypeInt32 && value_dtype() == TypeId::kNumberTypeFloat32) {
145 const auto &hash_table = user_data->get<HashTable<int, float>>(kUserDataData);
146 MS_EXCEPTION_IF_NULL(hash_table);
147 if (!hash_table->is_dirty()) {
148 return {key_tensor(), value_tensor(), status_tensor()};
149 }
150 if (last_slice) {
151 export_data = hash_table->ExportSlice(incremental, last_slice);
152 } else {
153 export_data = hash_table->Export(incremental);
154 }
155 } else if (key_dtype() == TypeId::kNumberTypeInt64 && value_dtype() == TypeId::kNumberTypeFloat32) {
156 const auto &hash_table = user_data->get<HashTable<int64_t, float>>(kUserDataData);
157 MS_EXCEPTION_IF_NULL(hash_table);
158 if (!hash_table->is_dirty()) {
159 return {key_tensor(), value_tensor(), status_tensor()};
160 }
161 if (last_slice) {
162 export_data = hash_table->ExportSlice(incremental, last_slice);
163 } else {
164 export_data = hash_table->Export(incremental);
165 }
166 } else {
167 MS_LOG(EXCEPTION) << "UnSupported Map Tensor type: key type is " << TypeIdToType(key_dtype()) << ", value type is "
168 << TypeIdToType(value_dtype()) << ".";
169 }
170 TransExportDataToTensor(export_data);
171
172 return {key_tensor(), value_tensor(), status_tensor()};
173 }
174
175 // If the data on the host side is valid, the data on the host side will be exported.
CheckData() const176 bool MapTensor::CheckData() const {
177 // check key
178 auto tensor_key = key_tensor();
179 MS_EXCEPTION_IF_NULL(tensor_key);
180 if (tensor_key->shape().size() != 1 || tensor_key->shape()[0] < 1) {
181 MS_LOG(WARNING) << "Invalid key tensor shape: " << tensor::ShapeToString(tensor_key->shape());
182 return false;
183 }
184 // check value
185 bool check_value =
186 std::any_of(value_shape().cbegin(), value_shape().cend(), [](const ShapeValueDType &shape) { return shape < 1; });
187 if (check_value) {
188 MS_LOG(WARNING) << "Invalid value tensor shape: " << tensor::ShapeToString(value_shape());
189 return false;
190 }
191 // check status
192 auto tensor_status = status_tensor();
193 MS_EXCEPTION_IF_NULL(tensor_status);
194 if (tensor_status->shape().size() != 1 || tensor_status->shape()[0] < 1) {
195 MS_LOG(WARNING) << "Invalid status tensor shape: " << tensor::ShapeToString(tensor_status->shape());
196 return false;
197 }
198 return true;
199 }
200
Export(bool incremental) const201 MapTensor::ExportData MapTensor::Export(bool incremental) const {
202 MS_LOG(DEBUG) << (incremental ? "Incremental" : "Full") << " export MapTensor";
203
204 // Check device
205 DeviceSyncPtr device_sync = device_address();
206 if (device_sync != nullptr) {
207 return ExportDataFromDevice(device_sync, incremental);
208 }
209 if (CheckData()) {
210 return {key_tensor(), value_tensor(), status_tensor()};
211 }
212 // Note: this is fake implementation.
213 ShapeVector key_shape = {1};
214 ShapeVector values_shape = ConcatShape(ShapeVector{1}, value_shape());
215 auto key_tensor = std::make_shared<Tensor>(key_dtype(), key_shape);
216 auto value_tensor = std::make_shared<Tensor>(value_dtype(), values_shape);
217 auto status_tensor = std::make_shared<Tensor>(kNumberTypeInt, key_shape);
218 return {key_tensor, value_tensor, status_tensor};
219 }
220
// Export one slice of the map tensor from the device-side hash table.
// `last_slice` must be non-null; it is passed through to the hash table's
// ExportSlice via ExportDataFromDevice. Requires a bound device address.
MapTensor::ExportData MapTensor::ExportSlice(bool incremental, bool *last_slice) const {
  MS_EXCEPTION_IF_NULL(last_slice);
  DeviceSyncPtr device_sync = device_address();
  MS_EXCEPTION_IF_NULL(device_sync);
  return ExportDataFromDevice(device_sync, incremental, last_slice);
}
227 } // namespace tensor
228 } // namespace mindspore
229