/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/kernel.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <set>

#include "kernel/format_utils.h"
#include "kernel/common_utils.h"
#include "utils/ms_context.h"
#include "include/backend/device_synchronizer_utils.h"

namespace mindspore {
namespace kernel {
constexpr int64_t kInvalidShape = -2;

namespace {
using ShapeTransposeFunc = std::function<void(const ShapeVector *, ShapeVector *)>;

void TransposeDefaultShape(const ShapeVector *host_shape_vector, ShapeVector *device_shape_vector) {
  MS_EXCEPTION_IF_NULL(host_shape_vector);
  MS_EXCEPTION_IF_NULL(device_shape_vector);
  *device_shape_vector = *host_shape_vector;
}

void TransposeNCHWShape(const ShapeVector *host_shape_vector, ShapeVector *device_shape_vector) {
  MS_EXCEPTION_IF_NULL(host_shape_vector);
  MS_EXCEPTION_IF_NULL(device_shape_vector);
  if (host_shape_vector->size() != kDim4) {
    MS_LOG(EXCEPTION) << "The host shape dims should be 4, but got: " << host_shape_vector->size();
  }
  *device_shape_vector = *host_shape_vector;
}

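// Transpose a 4-D host shape stored in the default (NCHW) order into the NHWC device layout:
// (N, C, H, W) -> (N, H, W, C).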
void TransposeNHWCShape(const ShapeVector *host_shape_vector, ShapeVector *device_shape_vector) {
  MS_EXCEPTION_IF_NULL(host_shape_vector);
  MS_EXCEPTION_IF_NULL(device_shape_vector);

  if (host_shape_vector->size() != kDim4) {
    MS_LOG(EXCEPTION) << "The host shape dims should be 4, but got: " << host_shape_vector->size();
  }
  device_shape_vector->resize(kDim4);

  device_shape_vector->at(kIndex0) = host_shape_vector->at(kIndex0);
  device_shape_vector->at(kIndex1) = host_shape_vector->at(kIndex2);
  device_shape_vector->at(kIndex2) = host_shape_vector->at(kIndex3);
  device_shape_vector->at(kIndex3) = host_shape_vector->at(kIndex1);
}
}  // namespace

KernelHostInfo::KernelHostInfo(const KernelHostInfo &other) {
  shape_vector_after_format_trasform_ = other.shape_vector_after_format_trasform_;
  type_id_ = other.type_id_;
  kernel_tensor_value_ = other.kernel_tensor_value_;
}

KernelTensor::KernelTensor() { address_common_ = std::make_shared<AddressCommon>(); }

KernelTensor::KernelTensor(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value) {
  host_info_ = std::make_unique<KernelHostInfo>();
  address_common_ = std::make_shared<AddressCommon>();

  if (type) {
    SetType(type);
  }
  if (shape) {
    // Note: for performance, `SetShape` relies on host_info_->type_id_, so `SetType` must be called first.
    SetShape(shape);
  }
  if (value) {
    SetValue(value);
  }
}

KernelTensor::KernelTensor(void *device_ptr, size_t size, Format format, TypeId dtype_id, const ShapeVector &host_shape,
                           const string &device_name, uint32_t device_id, const UserDataPtr &user_data)
    : host_shape_(host_shape),
      user_data_(user_data),
      address_common_(
        std::make_shared<AddressCommon>(device_ptr, size, host_shape, format, dtype_id, device_name, device_id)) {
  if (dtype_id == kTypeUnknown) {
    SetType(TypeIdToType(dtype_id));
  } else {
    SetType(std::make_shared<TensorType>(TypeIdToType(dtype_id)));
  }
}

KernelTensor::KernelTensor(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value,
                           void *device_ptr, size_t size, const std::string &format, TypeId dtype_id,
                           const ShapeVector &host_shape, const string &device_name, uint32_t device_id,
                           const UserDataPtr &user_data)
    : KernelTensor(shape, type, value) {
  address_common_->pointer_ref_count_->set_ptr(device_ptr);
  address_common_->size_ = size;
  address_common_->format_ = GetFormatFromStrToEnum(format);
  address_common_->dtype_id_ = dtype_id;
  address_common_->device_name_ = device_name;
  address_common_->device_id_ = device_id;
  host_shape_ = host_shape;
  user_data_ = user_data;
}

KernelTensor::KernelTensor(const AddressCommonPtr &address_common, const abstract::BaseShapePtr &shape,
                           const TypePtr &type, const ValuePtr &value, const ShapeVector &host_shape,
                           const UserDataPtr &user_data)
    : KernelTensor(shape, type, value) {
  address_common_ = address_common;
  host_shape_ = host_shape;
  user_data_ = user_data;
}

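// Copy constructor: clones the host info (shape, type, cached kernel tensor value) and copies the
// device-side info (task id on stream, address common, device synchronizer, host shape, user data).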
KernelTensor::KernelTensor(const KernelTensor &other) {
  // Copy host info.
  shape_ = other.shape_ != nullptr ? other.shape_->Clone() : abstract::kNoShape;
  type_ = other.type_ != nullptr ? other.type_->Clone() : kTypeAny;
  value_ = other.value_;

  if (other.host_info_) {
    host_info_ = std::make_unique<KernelHostInfo>(*other.host_info_);
    host_info_->kernel_tensor_value_ = other.host_info_->kernel_tensor_value_ != nullptr
                                         ? std::make_shared<KernelTensorValue>(*other.host_info_->kernel_tensor_value_)
                                         : nullptr;
  }

  // Copy device info.
  task_id_on_stream_ = other.task_id_on_stream_;
  address_common_ = std::make_shared<AddressCommon>(*other.address_common_);
  device_synchronizer_ = other.device_synchronizer_;
  host_shape_ = other.host_shape_;
  user_data_ = other.user_data_;
}

inline void KernelTensor::CheckHostInfoValid() {
  if (MS_UNLIKELY(!host_info_)) {
    host_info_ = std::make_unique<KernelHostInfo>();
  }
}
namespace {
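// Flatten a BaseShape into a single ShapeVector: NoShape -> {}, TensorShape -> its shape,
// dynamic sequences -> {-1}, empty sequences -> {0}, and nested sequences ->
// {element_count, first_element_shape...}.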
ShapeVector GetShapeVectorByBaseShape(const abstract::BaseShapePtr &base_shape) {
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::NoShape>()) {
    return {};
  } else if (base_shape->isa<abstract::Shape>()) {
    return base_shape->cast<abstract::ShapePtr>()->shape();
  } else if (base_shape->isa<abstract::DynamicSequenceShape>()) {
    return {-1};
  } else if (base_shape->isa<abstract::SequenceShape>()) {
    const auto &sequence_shape = base_shape->cast<abstract::SequenceShapePtr>();
    MS_EXCEPTION_IF_NULL(sequence_shape);
    if (sequence_shape->size() == 0) {
      return {0};
    }
    ShapeVector shape_vector = {SizeToLong(sequence_shape->size())};
    const auto &sub_shape_vector = GetShapeVectorByBaseShape(sequence_shape->shape()[0]);
    shape_vector.insert(shape_vector.end(), sub_shape_vector.begin(), sub_shape_vector.end());
    return shape_vector;
  }
  MS_LOG(EXCEPTION) << "Invalid shape:" << base_shape->ToString();
}
}  // namespace

void KernelTensor::SetHostInfo(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value) {
  CheckHostInfoValid();
  if (type) {
    SetType(type);
  }
  if (shape) {
    SetShape(shape);
  }
  if (value) {
    SetValue(value);
  }
}

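// Set the shape for the kernel tensor and update the flattened shape vector in address_common_
// according to the object type (Tensor/MapTensor, Tuple/List or others), then recalculate the memory size.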
void KernelTensor::SetShape(const abstract::BaseShapePtr &shape) {
  MS_EXCEPTION_IF_NULL(shape);
  shape_ = shape;
  CheckHostInfoValid();

  // Note: for performance, `SetShape` relies on host_info_->type_id_, so `SetType` must be called first.
  switch (host_info_->type_id_) {
    case kObjectTypeMapTensorType:
    case kObjectTypeTensorType: {
      // The shape type check affects performance; the following check will be removed once the framework is stable.
      if (shape_->isa<abstract::NoShape>()) {
        address_common_->shape_vector_ = {};
      } else {
        if (!shape_->isa<abstract::TensorShape>()) {
          MS_LOG(EXCEPTION) << "Expected TensorShape for SetShape, but got: " << shape_->type_name() << ", "
                            << shape_->ToString();
        }
        address_common_->shape_vector_ = shape_->GetShapeVector();
      }

      break;
    }

    case kObjectTypeList:
    case kObjectTypeTuple: {
      if (shape->isa<abstract::DynamicSequenceShape>()) {
        address_common_->shape_vector_ = {-1};
        break;
      }
      const auto &seq_shape = shape_->cast<abstract::SequenceShapePtr>();
      if (seq_shape == nullptr) {
        MS_LOG(EXCEPTION) << "Expected SequenceShape for SetShape, but got: " << shape_->type_name() << ", "
                          << shape_->ToString();
      }
      address_common_->shape_vector_.clear();
      address_common_->shape_vector_.push_back(seq_shape->size());
      const auto &shapes = seq_shape->shape();
      if (shapes.empty()) {
        break;
      }
      const auto &element_shape = shapes[0];
      MS_EXCEPTION_IF_NULL(element_shape);
      if (element_shape->isa<abstract::TensorShape>()) {
        const ShapeVector &element_shape_vector = element_shape->GetShapeVector();
        address_common_->shape_vector_.insert(address_common_->shape_vector_.end(), element_shape_vector.begin(),
                                              element_shape_vector.end());
      } else if (element_shape->isa<abstract::SequenceShape>()) {
        const ShapeVector &element_shape_vector = GetShapeVectorByBaseShape(element_shape);
        address_common_->shape_vector_.insert(address_common_->shape_vector_.end(), element_shape_vector.begin(),
                                              element_shape_vector.end());
      }

      break;
    }

    case kTypeUnknown: {
      MS_LOG(EXCEPTION) << "Cannot set shape for unknown type, please set a correct type for the kernel tensor first.";
    }

    default:
      MS_EXCEPTION_IF_NULL(type_);
      MS_LOG(DEBUG) << "No need to set shape for: " << type_->ToString();
  }

  // Update size_ after the shape changed.
  // Note: the memory size must be calculated after 'SetType' and 'SetShape'.
  CalculateMemSize();
}

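// Recalculate address_common_->size_ in bytes from the current shape vector and data type.
// For a dynamic shape the element count (and hence the size) is 0.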
void KernelTensor::CalculateMemSize() {
  MS_EXCEPTION_IF_NULL(host_info_);
  if (host_info_->type_id_ == kObjectTypeTensorType || host_info_->type_id_ == kObjectTypeTuple ||
      host_info_->type_id_ == kObjectTypeList) {
    // If address_common_->shape_vector_ is a dynamic shape, address_common_->size_ will be 0.
    size_t element_num = SizeOf(address_common_->shape_vector_);
    address_common_->size_ = element_num * UnitSizeInBytes(address_common_->dtype_id_);
  } else if (host_info_->type_id_ == kObjectTypeNumber) {
    address_common_->size_ = UnitSizeInBytes(address_common_->dtype_id_);
  }
}

void KernelTensor::SetShapeVector(const ShapeVector &shape_vector) {
  CheckHostInfoValid();
  if (host_info_->type_id_ == kObjectTypeTensorType || host_info_->type_id_ == kObjectTypeMapTensorType) {
    address_common_->shape_vector_ = shape_vector;
    MS_EXCEPTION_IF_NULL(shape_);
    shape_->SetShapeVector(address_common_->shape_vector_);

    MS_LOG(DEBUG) << "Set shape vector: " << shape_vector
                  << ", the format: " << GetFormatFromEnumToStr(address_common_->format_);
    return;
  }

  if (host_info_->type_id_ == kObjectTypeNumber) {
    if (!shape_vector.empty()) {
      MS_LOG(EXCEPTION) << "For Number type, the shape should be empty, but got " << shape_vector;
    }
    return;
  }

  MS_LOG(EXCEPTION) << "Currently only Scalar/Tensor/MapTensor types support setting a shape vector, but got type: "
                    << TypeIdLabel(host_info_->type_id_);
}

void KernelTensor::SetShapeVector(ShapeVector &&shape_vector) {
  CheckHostInfoValid();
  if (host_info_->type_id_ == kObjectTypeTensorType || host_info_->type_id_ == kObjectTypeMapTensorType) {
    address_common_->shape_vector_ = std::move(shape_vector);
    MS_EXCEPTION_IF_NULL(shape_);
    shape_->SetShapeVector(address_common_->shape_vector_);

    // Log the stored shape vector: `shape_vector` has already been moved from at this point.
    MS_LOG(DEBUG) << "Set shape vector: " << address_common_->shape_vector_
                  << ", the format: " << GetFormatFromEnumToStr(address_common_->format_);
    return;
  }

  if (host_info_->type_id_ == kObjectTypeNumber) {
    if (!shape_vector.empty()) {
      MS_LOG(EXCEPTION) << "For Number type, the shape should be empty, but got " << shape_vector;
    }
    return;
  }

  MS_LOG(EXCEPTION) << "Currently only Scalar/Tensor/MapTensor types support setting a shape vector, but got type: "
                    << TypeIdLabel(host_info_->type_id_);
}

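// Transform the host shape vector into the device shape vector for the current format
// (DEFAULT/NCHW/NHWC) and cache the result in host_info_->shape_vector_after_format_trasform_.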
const ShapeVector &KernelTensor::TransposeToDeviceShape() const {
  MS_EXCEPTION_IF_NULL(host_info_);
  if (host_info_->type_id_ != kObjectTypeTensorType) {
    MS_LOG(EXCEPTION) << "Only TensorType can be transposed to a device shape, but got: "
                      << TypeIdLabel(host_info_->type_id_);
  }

  static const mindspore::HashMap<mindspore::Format, ShapeTransposeFunc> shape_trans_funcs = {
    {Format::DEFAULT_FORMAT, TransposeDefaultShape},
    {Format::NCHW, TransposeNCHWShape},
    {Format::NHWC, TransposeNHWCShape}};

  auto iter = shape_trans_funcs.find(address_common_->format_);
  if (iter == shape_trans_funcs.end()) {
    MS_LOG(EXCEPTION) << "Can not find shape transpose function for format: "
                      << GetFormatFromEnumToStr(address_common_->format_);
  }

  // Compute the device shape corresponding to 'address_common_->shape_vector_'. For example, if the format is NHWC,
  // the device shape may differ from the host shape.
  iter->second(&address_common_->shape_vector_, &host_info_->shape_vector_after_format_trasform_);
  return host_info_->shape_vector_after_format_trasform_;
}

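// Whether the current format requires a host-to-device shape transpose. DEFAULT_FORMAT, NCHW, ND
// and NCDHW share the host layout, so they are excluded.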
bool KernelTensor::NeedTransposeToDeviceShape() const noexcept {
  static std::set<mindspore::Format> black_list{Format::DEFAULT_FORMAT, Format::NCHW, Format::ND, Format::NCDHW};
  auto it = black_list.find(address_common_->format_);
  return it == black_list.end();
}

const ShapeVector &KernelTensor::GetDeviceShapeVector() const {
  MS_EXCEPTION_IF_NULL(host_info_);
  if (NeedTransposeToDeviceShape()) {
    std::lock_guard<std::mutex> lock(host_info_->shape_transform_mutex_);
    return TransposeToDeviceShape();
  }
  return address_common_->shape_vector_;
}

void KernelTensor::SetType(const TypePtr &type) {
  MS_EXCEPTION_IF_NULL(type);
  CheckHostInfoValid();
  type_ = type;
  host_info_->type_id_ = type_->object_type();
  if (host_info_->type_id_ == kTypeUnknown) {
    host_info_->type_id_ = type_->type_id();
    MS_EXCEPTION_IF_CHECK_FAIL((host_info_->type_id_ != kTypeUnknown),
                               "Got an unknown type id, type info: " + type_->ToString());
  }

  switch (host_info_->type_id_) {
    case kObjectTypeTensorType: {
      auto tensor_type_ptr = type_->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_type_ptr);
      auto element_type = tensor_type_ptr->element();
      if (element_type) {
        address_common_->dtype_id_ = element_type->type_id();
      }
    } break;

    case kObjectTypeTuple: {
      auto tuple_type = type_->cast<TuplePtr>();
      MS_EXCEPTION_IF_NULL(tuple_type);
      TypePtr element_type = nullptr;
      if (tuple_type->dynamic_len()) {
        element_type = tuple_type->dynamic_element_type();
        if (element_type == nullptr) {
          return;
        }
      } else {
        const TypePtrList &element_types = tuple_type->elements();
        if (element_types.empty()) {
          return;
        }
        element_type = element_types[0];
      }
      SetSequenceDType(element_type);
    } break;

    case kObjectTypeList: {
      auto list_type = type_->cast<ListPtr>();
      MS_EXCEPTION_IF_NULL(list_type);
      TypePtr element_type = nullptr;
      if (list_type->dynamic_len()) {
        element_type = list_type->dynamic_element_type();
        if (element_type == nullptr) {
          return;
        }
      } else {
        const TypePtrList &element_types = list_type->elements();
        if (element_types.empty()) {
          return;
        }
        element_type = element_types[0];
      }
      SetSequenceDType(element_type);
    } break;

    default:
      address_common_->dtype_id_ = type->type_id();
      MS_LOG(DEBUG) << "Set dtype for: " << type->ToString();
  }
}

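// Deduce the element data type for sequence (Tuple/List) kernel tensors. Nested sequences are
// resolved recursively using their first (or dynamic) element type.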
void KernelTensor::SetSequenceDType(const TypePtr &element_type) {
  MS_EXCEPTION_IF_NULL(element_type);
  if (element_type->object_type() == kObjectTypeTensorType) {
    // Tensor type element.
    auto tensor_type_ptr = element_type->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_type_ptr);
    auto tensor_element_type = tensor_type_ptr->element();
    if (tensor_element_type) {
      address_common_->dtype_id_ = tensor_element_type->type_id();
    }
  } else if (element_type->object_type() == kObjectTypeNumber) {
    // Scalar type element.
    address_common_->dtype_id_ = element_type->type_id();
  } else if (element_type->object_type() == kObjectTypeString) {
    // String type element.
    address_common_->dtype_id_ = element_type->type_id();
  } else if (element_type->object_type() == kObjectTypeTuple) {
    // Sequence type element.
    auto tuple_type = element_type->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_type);
    if (tuple_type->dynamic_len()) {
      if (tuple_type->dynamic_element_type() == nullptr) {
        return;
      }
      SetSequenceDType(tuple_type->dynamic_element_type());
      return;
    }
    const TypePtrList &element_types = tuple_type->elements();
    if (element_types.empty() || element_types[0] == nullptr) {
      return;
    }
    SetSequenceDType(element_types[0]);
    return;
  } else if (element_type->object_type() == kObjectTypeList) {
    // Sequence type element.
    auto list_type = element_type->cast<ListPtr>();
    MS_EXCEPTION_IF_NULL(list_type);
    if (list_type->dynamic_len()) {
      if (list_type->dynamic_element_type() == nullptr) {
        return;
      }
      SetSequenceDType(list_type->dynamic_element_type());
      return;
    }
    const TypePtrList &element_types = list_type->elements();
    if (element_types.empty() || element_types[0] == nullptr) {
      return;
    }
    SetSequenceDType(element_types[0]);
    return;
  } else {
    MS_LOG(EXCEPTION) << "Unsupported element type[" << element_type->ToString()
                      << "] to set element data type for KernelTensor.";
  }
}

std::string KernelTensor::GetStringFormat() const { return GetFormatFromEnumToStr(address_common_->format_); }

void KernelTensor::SetStringFormat(const std::string &format) {
  address_common_->format_ = GetFormatFromStrToEnum(format);
}

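// Return the value of the kernel tensor as a KernelTensorValue. Prefer the origin host value if
// it exists; otherwise synchronize the data from the device side.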
ValuePtr KernelTensor::GetValue() const {
  MS_EXCEPTION_IF_NULL(host_info_);
  std::lock_guard<std::mutex> lock(host_info_->value_mutex_);

  // There may be an origin value in the KernelTensor (e.g. one that comes from a ValueNode).
  if (address_common_->dtype_id_ == kMetaTypeNone) {
    return kNone;
  } else if (value_ && !value_->isa<ValueAny>()) {
    if (host_info_->kernel_tensor_value_ == nullptr) {
      host_info_->kernel_tensor_value_ = ConvertValueToKernelTensorValue(value_);
      return host_info_->kernel_tensor_value_ ? host_info_->kernel_tensor_value_ : value_;
    }
    return host_info_->kernel_tensor_value_;
  }

  // Sync value data from device.
  if (!SyncDataFromDeviceToHost()) {
    MS_LOG(EXCEPTION) << "Sync data from device to host side failed";
  }
  return host_info_->kernel_tensor_value_;
}

const void *KernelTensor::GetValuePtr() {
  CheckHostInfoValid();
  std::lock_guard<std::mutex> lock(host_info_->value_mutex_);

  // There may be an origin value in the KernelTensor (e.g. one that comes from a ValueNode).
  if (address_common_->dtype_id_ == kMetaTypeNone) {
    return nullptr;
  } else if (value_ && !value_->isa<ValueAny>()) {
    if (host_info_->kernel_tensor_value_ == nullptr) {
      host_info_->kernel_tensor_value_ = ConvertValueToKernelTensorValue(value_);
    }
    MS_EXCEPTION_IF_NULL(host_info_->kernel_tensor_value_);
    return host_info_->kernel_tensor_value_->GetDataPtr();
  }

  // Sync value data from device.
  if (!SyncDataFromDeviceToHost()) {
    MS_LOG(EXCEPTION) << "Sync data from device to host side failed";
  }
  return host_info_->kernel_tensor_value_->GetDataPtr();
}

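// Copy the tensor data from device memory into host_info_->kernel_tensor_value_. On the CPU
// backend the device pointer is reused directly instead of being copied.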
bool KernelTensor::SyncDataFromDeviceToHost() const {
  // Note: the lock must be released while waiting for async resize or kernel launch to finish, because the awaited
  // resize and launch tasks may call this kernel tensor's GetValue and try to lock the same mutex, which would
  // deadlock.
  host_info_->value_mutex_.unlock();
  WaitAsyncResizeAndLaunchFinish();
  host_info_->value_mutex_.lock();

  void *device_ptr = this->device_ptr();
  if (device_ptr == nullptr) {
    MS_LOG(ERROR) << "Device memory is not allocated yet, sync data from device to host side failed, size: "
                  << address_common_->size_;
    return false;
  }

  MS_EXCEPTION_IF_NULL(host_info_);
  // For performance, the CPU back-end does not need to copy data from device to host; it directly uses the
  // device pointer in the kernel tensor.
  if (address_common_->device_name_ == kCPUDevice) {
    if (!host_info_->kernel_tensor_value_) {
      host_info_->kernel_tensor_value_ = std::make_shared<KernelTensorValue>(device_ptr, address_common_->size_, type_);
    } else {
      host_info_->kernel_tensor_value_->SetDataPtr(device_ptr);
      host_info_->kernel_tensor_value_->Resize(address_common_->size_);
    }
    return true;
  }

  if (!host_info_->kernel_tensor_value_) {
    host_info_->kernel_tensor_value_ = std::make_shared<KernelTensorValue>(address_common_->size_, type_);
  } else {
    host_info_->kernel_tensor_value_->Resize(address_common_->size_);
  }

  if (address_common_->size_ == 0) {
    return true;
  }

  void *host_ptr = host_info_->kernel_tensor_value_->GetMutableDataPtr();
  MS_EXCEPTION_IF_NULL(host_ptr);

  MS_EXCEPTION_IF_NULL(device_synchronizer_);
  if (!device_synchronizer_->SyncDeviceToHost(
        host_ptr, device_ptr, address_common_->size_, address_common_->device_name_, address_common_->device_id_,
        address_common_->format_, address_common_->shape_vector_, address_common_->stream_id_, user_data_)) {
    MS_LOG(EXCEPTION) << "Sync data from device to host side failed";
  }
  return true;
}

bool KernelTensor::IsDynamicShape() const {
  const auto &shape = this->GetShapeVector();
  return std::any_of(shape.cbegin(), shape.cend(), [](auto i) { return i < 0; });
}

ShapeVector KernelTensor::GetMaxShape() const {
  MS_EXCEPTION_IF_NULL(host_info_);
  if (host_info_->type_id_ != kObjectTypeTensorType) {
    return {};
  }
  if (shape_ == nullptr || !shape_->isa<abstract::Shape>()) {
    return {};
  }

  return shape_->cast<abstract::ShapePtr>()->max_shape();
}

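// Default resize: clear the workspace size list and recompute each output's byte size from its
// shape and dtype. Returns KRET_UNKNOWN_OUT_SHAPE if any output shape is still unknown.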
int KernelMod::Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) {
  auto ret = KRET_OK;
  workspace_size_list_.clear();
  output_size_list_.clear();

  for (size_t idx = 0; idx < outputs.size(); idx++) {
    auto &output = outputs[idx];
    size_t tensor_size = 0;
    MS_EXCEPTION_IF_NULL(output);
    size_t type_size = UnitSizeInBytes(output->dtype_id());
    if (type_size == 0) {
      MS_LOG(WARNING) << "The type size is 0, type: " << TypeIdToType(output->dtype_id())->ToString();
    }

    const auto &shape = output->GetShapeVector();
    if (!IsValidShape(shape)) {
      MS_LOG(WARNING) << "Invalid shape:" << mindspore::ToString(shape) << ", kernel name:" << kernel_name();
      // Note:
      // If the output shape is unknown, the op is a compute-depended op, and output_size_list_ is set to the
      // default size: type_size.
      tensor_size = type_size;
      ret = KRET_UNKNOWN_OUT_SHAPE;
    } else {
      if (shape.empty()) {
        tensor_size = type_size;
      } else {
        auto cur_out_shape_num = SizeOf(shape);
        tensor_size = cur_out_shape_num * type_size;
        if (type_size != 0 && tensor_size / type_size != cur_out_shape_num) {
          MS_EXCEPTION(ValueError) << "For " << kernel_name_ << ", the shape of outputs[" << output_size_list_.size()
                                   << "]: " << shape
                                   << " is too big, MindSpore cannot allocate such a large amount of memory.";
        }
      }
    }
    (void)output_size_list_.emplace_back(tensor_size);
  }
  return static_cast<int>(ret);
}

std::vector<std::vector<int64_t>> GetShapes(const std::vector<KernelTensor *> &tensors) {
  std::vector<std::vector<int64_t>> shapes(tensors.size());
  for (size_t idx = 0; idx < shapes.size(); idx++) {
    shapes[idx] = tensors[idx]->GetShapeVector();
  }
  return shapes;
}

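// Convert the kernel tensors in a KernelLaunchInfo into plain Address objects (device pointer and
// size) for inputs, workspaces and outputs.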
void ConvertLaunchInfoToAddr(const KernelLaunchInfo &launch_info, KernelLaunchAddr *mem_info) {
  (mem_info->inputs_).clear();
  (mem_info->outputs_).clear();
  (mem_info->workspaces_).clear();
  std::transform((launch_info.inputs_).begin(), (launch_info.inputs_).end(), std::back_inserter(mem_info->inputs_),
                 [](const auto &input) { return std::make_shared<Address>(input->device_ptr(), input->size()); });
  std::transform(
    (launch_info.workspaces_).begin(), (launch_info.workspaces_).end(), std::back_inserter(mem_info->workspaces_),
    [](const auto &workspace) { return std::make_shared<Address>(workspace->device_ptr(), workspace->size()); });
  std::transform((launch_info.outputs_).begin(), (launch_info.outputs_).end(), std::back_inserter(mem_info->outputs_),
                 [](const auto &output) { return std::make_shared<Address>(output->device_ptr(), output->size()); });
}
}  // namespace kernel
}  // namespace mindspore