• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ir/map_tensor.h"
18 #include <vector>
19 #include <algorithm>
20 #include "abstract/abstract_value.h"
21 #include "ir/tensor.h"
22 #include "utils/log_adapter.h"
23 #include "utils/ms_utils_secure.h"
24 #include "runtime/device/hash_table.h"
25 
26 namespace mindspore {
27 using device::HashTable;
28 namespace tensor {
29 using tensor::Tensor;
30 using tensor::TensorPtr;
// Positions of the key, value and status buffers inside the hash table export data
// (used to index the export data in MapTensor::TransExportDataToTensor).
31 constexpr size_t kKeyTensorIndex = 0;
32 constexpr size_t kValueTensorIndex = 1;
33 constexpr size_t kStatusTensorIndex = 2;
// Number of buffers the export data must contain; any other size is rejected as invalid.
34 constexpr size_t kExportTensorNum = 3;
35 
// Returns a new shape made of all dimensions of `a` followed by all dimensions of `b`.
static ShapeVector ConcatShape(const ShapeVector &a, const ShapeVector &b) {
  ShapeVector joined;
  joined.reserve(a.size() + b.size());
  (void)joined.insert(joined.end(), a.cbegin(), a.cend());
  (void)joined.insert(joined.end(), b.cbegin(), b.cend());
  return joined;
}
41 
hash() const42 std::size_t MapTensor::hash() const { return static_cast<std::size_t>(tid()); }
43 
operator ==(const MapTensor & other) const44 bool MapTensor::operator==(const MapTensor &other) const { return this == &other; }
45 
ToAbstract()46 abstract::AbstractBasePtr MapTensor::ToAbstract() {
47   if (param_info_ != nullptr) {
48     // For parameter, a broaden abstract is created with ref_key set.
49     ValuePtr ref_key = std::make_shared<RefKey>(param_info_->name());
50     return std::make_shared<abstract::AbstractMapTensor>(shared_from_base<MapTensor>(), ref_key);
51   } else {
52     // For value, an abstract is created with value set.
53     return std::make_shared<abstract::AbstractMapTensor>(shared_from_base<MapTensor>());
54   }
55 }
56 
ToString() const57 std::string MapTensor::ToString() const {
58   auto key_dtype = KeyDtype();
59   auto value_dtype = ValueDtype();
60   return "MapTensor(key_dtype=" + (key_dtype == nullptr ? "<null>" : key_dtype->ToString()) +
61          ", value_dtype=" + (value_dtype == nullptr ? "<null>" : value_dtype->ToString()) +
62          ", value_shape=" + tensor::ShapeToString(value_shape()) +
63          ", default_value=" + (default_value_ == nullptr ? "<null>" : default_value_->ToString()) +
64          ", permit_filter=" + (permit_filter_value_ == nullptr ? "<null>" : permit_filter_value_->ToString()) +
65          ", evict_filter=" + (evict_filter_value_ == nullptr ? "<null>" : evict_filter_value_->ToString()) + ")";
66 }
67 
Update(const MapTensor::ExportData & data)68 void MapTensor::Update(const MapTensor::ExportData &data) {
69   MS_EXCEPTION_IF_NULL(data.key_tensor);
70   MS_EXCEPTION_IF_NULL(data.value_tensor);
71   MS_EXCEPTION_IF_NULL(data.status_tensor);
72   key_tensor_ = data.key_tensor;
73   value_tensor_ = data.value_tensor;
74   status_tensor_ = data.status_tensor;
75 }
76 
TransExportDataToTensor(const HashTableExportData & export_data) const77 void MapTensor::TransExportDataToTensor(const HashTableExportData &export_data) const {
78   if (export_data.size() != kExportTensorNum) {
79     MS_LOG(EXCEPTION) << "Invalid MapTensor export data.";
80   }
81 
82   auto keys = export_data.at(kKeyTensorIndex);
83   auto values = export_data.at(kValueTensorIndex);
84   auto statuses = export_data.at(kStatusTensorIndex);
85   MS_EXCEPTION_IF_NULL(keys);
86   MS_EXCEPTION_IF_NULL(values);
87   MS_EXCEPTION_IF_NULL(statuses);
88 
89   // The key tensor.
90   auto keys_length = keys->size();
91   auto keys_num = keys_length / abstract::TypeIdSize(key_dtype());
92   ShapeVector key_tensor_shape{SizeToLong(keys_num)};
93   auto tensor_key = key_tensor();
94   MS_EXCEPTION_IF_NULL(tensor_key);
95   (void)tensor_key->set_shape(key_tensor_shape);
96   if (keys_length > 0) {
97     auto ret = memcpy_s(tensor_key->data_c(), tensor_key->Size(), keys->data(), keys_length);
98     if (ret != EOK) {
99       MS_LOG(INTERNAL_EXCEPTION) << "Memcpy for key tensor failed, errno[" << ret << "]";
100     }
101   }
102 
103   // The value tensor.
104   auto values_length = values->size();
105   auto element_length = LongToSize(abstract::ShapeSize(value_shape())) * abstract::TypeIdSize(value_dtype());
106   MS_EXCEPTION_IF_ZERO("element_length", element_length);
107   auto values_num = values_length / element_length;
108   ShapeVector value_tensor_shape{SizeToLong(values_num)};
109   (void)std::copy(value_shape().cbegin(), value_shape().cend(), std::back_inserter(value_tensor_shape));
110   auto tensor_value = value_tensor();
111   MS_EXCEPTION_IF_NULL(tensor_value);
112   (void)tensor_value->set_shape(value_tensor_shape);
113   if (values_length > 0) {
114     auto ret = memcpy_s(tensor_value->data_c(), tensor_value->Size(), values->data(), values_length);
115     if (ret != EOK) {
116       MS_LOG(INTERNAL_EXCEPTION) << "Memcpy for value tensor failed, errno[" << ret << "]";
117     }
118   }
119 
120   // The status tensor
121   auto statuses_length = statuses->size();
122   auto statuses_num = statuses_length / abstract::TypeIdSize(kNumberTypeInt);
123   // The status tensor shape is same as the shape of key tensor.
124   if (statuses_num != keys_num) {
125     MS_LOG(INTERNAL_EXCEPTION) << "Invalid export data: keys num: " << keys_num << ", statuses num: " << statuses_num;
126   }
127   ShapeVector status_tensor_shape{SizeToLong(statuses_num)};
128   auto tensor_status = status_tensor();
129   MS_EXCEPTION_IF_NULL(tensor_status);
130   (void)tensor_status->set_shape(status_tensor_shape);
131   if (statuses_length > 0) {
132     auto ret = memcpy_s(tensor_status->data_c(), tensor_status->Size(), statuses->data(), statuses_length);
133     if (ret != EOK) {
134       MS_LOG(INTERNAL_EXCEPTION) << "Memcpy for status tensor failed, errno[" << ret << "]";
135     }
136   }
137 }
138 
ExportDataFromDevice(const DeviceSyncPtr & device_sync,bool incremental,bool * last_slice) const139 MapTensor::ExportData MapTensor::ExportDataFromDevice(const DeviceSyncPtr &device_sync, bool incremental,
140                                                       bool *last_slice) const {
141   auto user_data = device_sync->user_data();
142   MS_EXCEPTION_IF_NULL(user_data);
143   HashTableExportData export_data;
144   if (key_dtype() == TypeId::kNumberTypeInt32 && value_dtype() == TypeId::kNumberTypeFloat32) {
145     const auto &hash_table = user_data->get<HashTable<int, float>>(kUserDataData);
146     MS_EXCEPTION_IF_NULL(hash_table);
147     if (!hash_table->is_dirty()) {
148       return {key_tensor(), value_tensor(), status_tensor()};
149     }
150     if (last_slice) {
151       export_data = hash_table->ExportSlice(incremental, last_slice);
152     } else {
153       export_data = hash_table->Export(incremental);
154     }
155   } else if (key_dtype() == TypeId::kNumberTypeInt64 && value_dtype() == TypeId::kNumberTypeFloat32) {
156     const auto &hash_table = user_data->get<HashTable<int64_t, float>>(kUserDataData);
157     MS_EXCEPTION_IF_NULL(hash_table);
158     if (!hash_table->is_dirty()) {
159       return {key_tensor(), value_tensor(), status_tensor()};
160     }
161     if (last_slice) {
162       export_data = hash_table->ExportSlice(incremental, last_slice);
163     } else {
164       export_data = hash_table->Export(incremental);
165     }
166   } else {
167     MS_LOG(EXCEPTION) << "UnSupported Map Tensor type: key type is " << TypeIdToType(key_dtype()) << ", value type is "
168                       << TypeIdToType(value_dtype()) << ".";
169   }
170   TransExportDataToTensor(export_data);
171 
172   return {key_tensor(), value_tensor(), status_tensor()};
173 }
174 
175 // If the data on the host side is valid, the data on the host side will be exported.
CheckData() const176 bool MapTensor::CheckData() const {
177   // check key
178   auto tensor_key = key_tensor();
179   MS_EXCEPTION_IF_NULL(tensor_key);
180   if (tensor_key->shape().size() != 1 || tensor_key->shape()[0] < 1) {
181     MS_LOG(WARNING) << "Invalid key tensor shape: " << tensor::ShapeToString(tensor_key->shape());
182     return false;
183   }
184   // check value
185   bool check_value =
186     std::any_of(value_shape().cbegin(), value_shape().cend(), [](const ShapeValueDType &shape) { return shape < 1; });
187   if (check_value) {
188     MS_LOG(WARNING) << "Invalid value tensor shape: " << tensor::ShapeToString(value_shape());
189     return false;
190   }
191   // check status
192   auto tensor_status = status_tensor();
193   MS_EXCEPTION_IF_NULL(tensor_status);
194   if (tensor_status->shape().size() != 1 || tensor_status->shape()[0] < 1) {
195     MS_LOG(WARNING) << "Invalid status tensor shape: " << tensor::ShapeToString(tensor_status->shape());
196     return false;
197   }
198   return true;
199 }
200 
Export(bool incremental) const201 MapTensor::ExportData MapTensor::Export(bool incremental) const {
202   MS_LOG(DEBUG) << (incremental ? "Incremental" : "Full") << " export MapTensor";
203 
204   // Check device
205   DeviceSyncPtr device_sync = device_address();
206   if (device_sync != nullptr) {
207     return ExportDataFromDevice(device_sync, incremental);
208   }
209   if (CheckData()) {
210     return {key_tensor(), value_tensor(), status_tensor()};
211   }
212   // Note: this is fake implementation.
213   ShapeVector key_shape = {1};
214   ShapeVector values_shape = ConcatShape(ShapeVector{1}, value_shape());
215   auto key_tensor = std::make_shared<Tensor>(key_dtype(), key_shape);
216   auto value_tensor = std::make_shared<Tensor>(value_dtype(), values_shape);
217   auto status_tensor = std::make_shared<Tensor>(kNumberTypeInt, key_shape);
218   return {key_tensor, value_tensor, status_tensor};
219 }
220 
ExportSlice(bool incremental,bool * last_slice) const221 MapTensor::ExportData MapTensor::ExportSlice(bool incremental, bool *last_slice) const {
222   MS_EXCEPTION_IF_NULL(last_slice);
223   DeviceSyncPtr device_sync = device_address();
224   MS_EXCEPTION_IF_NULL(device_sync);
225   return ExportDataFromDevice(device_sync, incremental, last_slice);
226 }
227 }  // namespace tensor
228 }  // namespace mindspore
229