/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/kernel.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <set>

#include "kernel/format_utils.h"
#include "kernel/common_utils.h"
#include "utils/ms_context.h"
#include "include/backend/device_synchronizer_utils.h"

namespace mindspore {
namespace kernel {
constexpr int64_t kInvalidShape = -2;

namespace {
using ShapeTransposeFunc = std::function<void(const ShapeVector *, ShapeVector *)>;

void TransposeDefaultShape(const ShapeVector *host_shape_vector, ShapeVector *device_shape_vector) {
  MS_EXCEPTION_IF_NULL(host_shape_vector);
  MS_EXCEPTION_IF_NULL(device_shape_vector);
  *device_shape_vector = *host_shape_vector;
}

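// For 4-D tensors the host shape is already laid out as NCHW, so this function only
// validates the rank and copies the shape through unchanged.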
void TransposeNCHWShape(const ShapeVector *host_shape_vector, ShapeVector *device_shape_vector) {
  MS_EXCEPTION_IF_NULL(host_shape_vector);
  MS_EXCEPTION_IF_NULL(device_shape_vector);
  if (host_shape_vector->size() != kDim4) {
    MS_LOG(EXCEPTION) << "The host shape dims should be 4, but got: " << host_shape_vector->size();
  }
  *device_shape_vector = *host_shape_vector;
}

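// Maps a host NCHW shape to its NHWC device layout by permuting the dimensions
// (N, C, H, W) -> (N, H, W, C). For example, a host shape {1, 3, 224, 224} becomes
// the device shape {1, 224, 224, 3}.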
void TransposeNHWCShape(const ShapeVector *host_shape_vector, ShapeVector *device_shape_vector) {
  MS_EXCEPTION_IF_NULL(host_shape_vector);
  MS_EXCEPTION_IF_NULL(device_shape_vector);

  if (host_shape_vector->size() != kDim4) {
    MS_LOG(EXCEPTION) << "The host shape dims should be 4, but got: " << host_shape_vector->size();
  }
  device_shape_vector->resize(kDim4);

  device_shape_vector->at(kIndex0) = host_shape_vector->at(kIndex0);
  device_shape_vector->at(kIndex1) = host_shape_vector->at(kIndex2);
  device_shape_vector->at(kIndex2) = host_shape_vector->at(kIndex3);
  device_shape_vector->at(kIndex3) = host_shape_vector->at(kIndex1);
}
}  // namespace

KernelHostInfo::KernelHostInfo(const KernelHostInfo &other) {
  shape_vector_after_format_trasform_ = other.shape_vector_after_format_trasform_;
  type_id_ = other.type_id_;
  kernel_tensor_value_ = other.kernel_tensor_value_;
}

KernelTensor::KernelTensor() { address_common_ = std::make_shared<AddressCommon>(); }

KernelTensor::KernelTensor(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value) {
  host_info_ = std::make_unique<KernelHostInfo>();
  address_common_ = std::make_shared<AddressCommon>();

  if (type) {
    SetType(type);
  }
  if (shape) {
    // Note: for performance, `SetShape` relies on host_info_->type_id_, so `SetType` must be called first.
    SetShape(shape);
  }
  if (value) {
    SetValue(value);
  }
}

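// A minimal usage sketch (hypothetical values, for illustration only): build a host-side
// KernelTensor describing a 2 x 3 float32 tensor with no known value yet:
//   auto kernel_tensor = std::make_shared<KernelTensor>(
//     std::make_shared<abstract::TensorShape>(ShapeVector{2, 3}),
//     std::make_shared<TensorType>(kFloat32), nullptr);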
KernelTensor::KernelTensor(void *device_ptr, size_t size, Format format, TypeId dtype_id, const ShapeVector &host_shape,
                           const string &device_name, uint32_t device_id, const UserDataPtr &user_data)
    : host_shape_(host_shape),
      user_data_(user_data),
      address_common_(
        std::make_shared<AddressCommon>(device_ptr, size, host_shape, format, dtype_id, device_name, device_id)) {
  if (dtype_id == kTypeUnknown) {
    SetType(TypeIdToType(dtype_id));
  } else {
    SetType(std::make_shared<TensorType>(TypeIdToType(dtype_id)));
  }
}

KernelTensor::KernelTensor(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value,
                           void *device_ptr, size_t size, const std::string &format, TypeId dtype_id,
                           const ShapeVector &host_shape, const string &device_name, uint32_t device_id,
                           const UserDataPtr &user_data)
    : KernelTensor(shape, type, value) {
  address_common_->pointer_ref_count_->set_ptr(device_ptr);
  address_common_->size_ = size;
  address_common_->format_ = GetFormatFromStrToEnum(format);
  address_common_->dtype_id_ = dtype_id;
  address_common_->device_name_ = device_name;
  address_common_->device_id_ = device_id;
  host_shape_ = host_shape;
  user_data_ = user_data;
}

KernelTensor::KernelTensor(const AddressCommonPtr &address_common, const abstract::BaseShapePtr &shape,
                           const TypePtr &type, const ValuePtr &value, const ShapeVector &host_shape,
                           const UserDataPtr &user_data)
    : KernelTensor(shape, type, value) {
  address_common_ = address_common;
  host_shape_ = host_shape;
  user_data_ = user_data;
}

KernelTensor::KernelTensor(const KernelTensor &other) {
  // Copy host info.
  shape_ = other.shape_ != nullptr ? other.shape_->Clone() : abstract::kNoShape;
  type_ = other.type_ != nullptr ? other.type_->Clone() : kTypeAny;
  value_ = other.value_;

  if (other.host_info_) {
    host_info_ = std::make_unique<KernelHostInfo>(*other.host_info_);
    host_info_->kernel_tensor_value_ = other.host_info_->kernel_tensor_value_ != nullptr
                                         ? std::make_shared<KernelTensorValue>(*other.host_info_->kernel_tensor_value_)
                                         : nullptr;
  }

  // Copy device info.
  task_id_on_stream_ = other.task_id_on_stream_;
  address_common_ = std::make_shared<AddressCommon>(*other.address_common_);
  device_synchronizer_ = other.device_synchronizer_;
  host_shape_ = other.host_shape_;
  user_data_ = other.user_data_;
}

inline void KernelTensor::CheckHostInfoValid() {
  if (MS_UNLIKELY(!host_info_)) {
    host_info_ = std::make_unique<KernelHostInfo>();
  }
}
namespace {
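// Flattens a (possibly nested) BaseShape into a single shape vector. For example, a
// SequenceShape holding two tensors of shape {3, 4} flattens to {2, 3, 4}; a dynamic
// sequence flattens to {-1}, and an empty sequence to {0}.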
ShapeVector GetShapeVectorByBaseShape(const abstract::BaseShapePtr &base_shape) {
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::NoShape>()) {
    return {};
  } else if (base_shape->isa<abstract::Shape>()) {
    return base_shape->cast<abstract::ShapePtr>()->shape();
  } else if (base_shape->isa<abstract::DynamicSequenceShape>()) {
    return {-1};
  } else if (base_shape->isa<abstract::SequenceShape>()) {
    const auto &sequence_shape = base_shape->cast<abstract::SequenceShapePtr>();
    MS_EXCEPTION_IF_NULL(sequence_shape);
    if (sequence_shape->size() == 0) {
      return {0};
    }
    ShapeVector shape_vector = {SizeToLong(sequence_shape->size())};
    const auto &sub_shape_vector = GetShapeVectorByBaseShape(sequence_shape->shape()[0]);
    shape_vector.insert(shape_vector.end(), sub_shape_vector.begin(), sub_shape_vector.end());
    return shape_vector;
  }
  MS_LOG(EXCEPTION) << "Invalid shape:" << base_shape->ToString();
}
}  // namespace

void KernelTensor::SetHostInfo(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value) {
  CheckHostInfoValid();
  if (type) {
    SetType(type);
  }
  if (shape) {
    SetShape(shape);
  }
  if (value) {
    SetValue(value);
  }
}

void KernelTensor::SetShape(const abstract::BaseShapePtr &shape) {
  MS_EXCEPTION_IF_NULL(shape);
  shape_ = shape;
  CheckHostInfoValid();

  // Note: for performance, `SetShape` relies on host_info_->type_id_, so `SetType` must be called first.
  switch (host_info_->type_id_) {
    case kObjectTypeMapTensorType:
    case kObjectTypeTensorType: {
      // The shape type check affects performance; the following check will be removed once the framework
      // is stable.
      if (shape_->isa<abstract::NoShape>()) {
        address_common_->shape_vector_ = {};
      } else {
        if (!shape_->isa<abstract::TensorShape>()) {
          MS_LOG(EXCEPTION) << "Expected TensorShape for SetShape, but got: " << shape_->type_name() << ", "
                            << shape_->ToString();
        }
        address_common_->shape_vector_ = shape_->GetShapeVector();
      }

      break;
    }

    case kObjectTypeList:
    case kObjectTypeTuple: {
      if (shape->isa<abstract::DynamicSequenceShape>()) {
        address_common_->shape_vector_ = {-1};
        break;
      }
      const auto &seq_shape = shape_->cast<abstract::SequenceShapePtr>();
      if (seq_shape == nullptr) {
        MS_LOG(EXCEPTION) << "Expected SequenceShape for SetShape, but got: " << shape_->type_name() << ", "
                          << shape_->ToString();
      }
      address_common_->shape_vector_.clear();
      address_common_->shape_vector_.push_back(seq_shape->size());
      const auto &shapes = seq_shape->shape();
      if (shapes.empty()) {
        break;
      }
      const auto &element_shape = shapes[0];
      MS_EXCEPTION_IF_NULL(element_shape);
      if (element_shape->isa<abstract::TensorShape>()) {
        const ShapeVector &element_shape_vector = element_shape->GetShapeVector();
        address_common_->shape_vector_.insert(address_common_->shape_vector_.end(), element_shape_vector.begin(),
                                              element_shape_vector.end());
      } else if (element_shape->isa<abstract::SequenceShape>()) {
        const ShapeVector &element_shape_vector = GetShapeVectorByBaseShape(element_shape);
        address_common_->shape_vector_.insert(address_common_->shape_vector_.end(), element_shape_vector.begin(),
                                              element_shape_vector.end());
      }

      break;
    }

    case kTypeUnknown: {
      MS_LOG(EXCEPTION) << "Cannot set shape for an unknown type; please set a correct type for the kernel tensor first.";
    }

    default:
      MS_EXCEPTION_IF_NULL(type_);
      MS_LOG(DEBUG) << "No need to set shape for: " << type_->ToString();
  }

  // Update size_ after the shape has changed.
  // Note: memory size calculation must run after both `SetType` and `SetShape`.
  CalculateMemSize();
}

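// Worked example: for a tensor of shape {2, 3} with dtype float32, SizeOf returns
// 6 elements and UnitSizeInBytes returns 4, so size_ becomes 24 bytes. For a dynamic
// shape, SizeOf yields 0 and size_ stays 0 (see the comment in the body below).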
void KernelTensor::CalculateMemSize() {
  MS_EXCEPTION_IF_NULL(host_info_);
  if (host_info_->type_id_ == kObjectTypeTensorType || host_info_->type_id_ == kObjectTypeTuple ||
      host_info_->type_id_ == kObjectTypeList) {
    // If address_common_->shape_vector_ is a dynamic shape, the computed size_ will be 0.
    size_t element_num = SizeOf(address_common_->shape_vector_);
    address_common_->size_ = element_num * UnitSizeInBytes(address_common_->dtype_id_);
  } else if (host_info_->type_id_ == kObjectTypeNumber) {
    address_common_->size_ = UnitSizeInBytes(address_common_->dtype_id_);
  }
}

void KernelTensor::SetShapeVector(const ShapeVector &shape_vector) {
  CheckHostInfoValid();
  if (host_info_->type_id_ == kObjectTypeTensorType || host_info_->type_id_ == kObjectTypeMapTensorType) {
    address_common_->shape_vector_ = shape_vector;
    MS_EXCEPTION_IF_NULL(shape_);
    shape_->SetShapeVector(address_common_->shape_vector_);

    MS_LOG(DEBUG) << "Set shape vector: " << shape_vector
                  << ", the format: " << GetFormatFromEnumToStr(address_common_->format_);
    return;
  }

  if (host_info_->type_id_ == kObjectTypeNumber) {
    if (!shape_vector.empty()) {
      MS_LOG(EXCEPTION) << "For Number type, the shape should be empty, but got " << shape_vector;
    }
    return;
  }

  MS_LOG(EXCEPTION) << "Currently only Scalar/Tensor/MapTensor types support setting the shape vector, but got type: "
                    << TypeIdLabel(host_info_->type_id_);
}

void KernelTensor::SetShapeVector(ShapeVector &&shape_vector) {
  CheckHostInfoValid();
  if (host_info_->type_id_ == kObjectTypeTensorType || host_info_->type_id_ == kObjectTypeMapTensorType) {
    address_common_->shape_vector_ = std::move(shape_vector);
    MS_EXCEPTION_IF_NULL(shape_);
    shape_->SetShapeVector(address_common_->shape_vector_);

    // Log the stored shape: `shape_vector` has been moved from and must not be read here.
    MS_LOG(DEBUG) << "Set shape vector: " << address_common_->shape_vector_
                  << ", the format: " << GetFormatFromEnumToStr(address_common_->format_);
    return;
  }

  if (host_info_->type_id_ == kObjectTypeNumber) {
    if (!shape_vector.empty()) {
      MS_LOG(EXCEPTION) << "For Number type, the shape should be empty, but got " << shape_vector;
    }
    return;
  }

  MS_LOG(EXCEPTION) << "Currently only Scalar/Tensor/MapTensor types support setting the shape vector, but got type: "
                    << TypeIdLabel(host_info_->type_id_);
}

const ShapeVector &KernelTensor::TransposeToDeviceShape() const {
  MS_EXCEPTION_IF_NULL(host_info_);
  if (host_info_->type_id_ != kObjectTypeTensorType) {
    MS_LOG(EXCEPTION) << "Only TensorType supports transposing to a device shape, but got: "
                      << TypeIdLabel(host_info_->type_id_);
  }

  static const mindspore::HashMap<mindspore::Format, ShapeTransposeFunc> shape_trans_funcs = {
    {Format::DEFAULT_FORMAT, TransposeDefaultShape},
    {Format::NCHW, TransposeNCHWShape},
    {Format::NHWC, TransposeNHWCShape}};

  auto iter = shape_trans_funcs.find(address_common_->format_);
  if (iter == shape_trans_funcs.end()) {
    MS_LOG(EXCEPTION) << "Cannot find a shape transpose function for format: "
                      << GetFormatFromEnumToStr(address_common_->format_);
  }

  // Compute the device shape corresponding to 'address_common_->shape_vector_'. For example, when the
  // format is NHWC, the device shape and host shape may differ.
  iter->second(&address_common_->shape_vector_, &host_info_->shape_vector_after_format_trasform_);
  return host_info_->shape_vector_after_format_trasform_;
}

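// Formats in the set below share the same layout on host and device, so no transpose is
// needed; any other format (e.g. NHWC) requires TransposeToDeviceShape.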
bool KernelTensor::NeedTransposeToDeviceShape() const noexcept {
  static std::set<mindspore::Format> black_list{Format::DEFAULT_FORMAT, Format::NCHW, Format::ND, Format::NCDHW};
  auto it = black_list.find(address_common_->format_);
  return it == black_list.end();
}

const ShapeVector &KernelTensor::GetDeviceShapeVector() const {
  MS_EXCEPTION_IF_NULL(host_info_);
  if (NeedTransposeToDeviceShape()) {
    std::lock_guard<std::mutex> lock(host_info_->shape_transform_mutex_);
    return TransposeToDeviceShape();
  }
  return address_common_->shape_vector_;
}

void KernelTensor::SetType(const TypePtr &type) {
  MS_EXCEPTION_IF_NULL(type);
  CheckHostInfoValid();
  type_ = type;
  host_info_->type_id_ = type_->object_type();
  if (host_info_->type_id_ == kTypeUnknown) {
    host_info_->type_id_ = type_->type_id();
    MS_EXCEPTION_IF_CHECK_FAIL((host_info_->type_id_ != kTypeUnknown),
                               "Got an unknown type id, type info: " + type_->ToString());
  }

  switch (host_info_->type_id_) {
    case kObjectTypeTensorType: {
      auto tensor_type_ptr = type_->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_type_ptr);
      auto element_type = tensor_type_ptr->element();
      if (element_type) {
        address_common_->dtype_id_ = element_type->type_id();
      }
    } break;

    case kObjectTypeTuple: {
      auto tuple_type = type_->cast<TuplePtr>();
      MS_EXCEPTION_IF_NULL(tuple_type);
      TypePtr element_type = nullptr;
      if (tuple_type->dynamic_len()) {
        element_type = tuple_type->dynamic_element_type();
        if (element_type == nullptr) {
          return;
        }
      } else {
        const TypePtrList &element_types = tuple_type->elements();
        if (element_types.empty()) {
          return;
        }
        element_type = element_types[0];
      }
      SetSequenceDType(element_type);
    } break;

    case kObjectTypeList: {
      auto list_type = type_->cast<ListPtr>();
      MS_EXCEPTION_IF_NULL(list_type);
      TypePtr element_type = nullptr;
      if (list_type->dynamic_len()) {
        element_type = list_type->dynamic_element_type();
        if (element_type == nullptr) {
          return;
        }
      } else {
        const TypePtrList &element_types = list_type->elements();
        if (element_types.empty()) {
          return;
        }
        element_type = element_types[0];
      }
      SetSequenceDType(element_type);
    } break;

    default:
      address_common_->dtype_id_ = type->type_id();
      MS_LOG(DEBUG) << "Set dtype for: " << type->ToString();
  }
}

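// Resolves the element dtype of a sequence by recursing into nested tuples/lists and
// reading the first element's type; the elements of a sequence are assumed homogeneous.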
void KernelTensor::SetSequenceDType(const TypePtr &element_type) {
  MS_EXCEPTION_IF_NULL(element_type);
  if (element_type->object_type() == kObjectTypeTensorType) {
    // Tensor type element.
    auto tensor_type_ptr = element_type->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_type_ptr);
    auto tensor_element_type = tensor_type_ptr->element();
    if (tensor_element_type) {
      address_common_->dtype_id_ = tensor_element_type->type_id();
    }
  } else if (element_type->object_type() == kObjectTypeNumber) {
    // Scalar type element.
    address_common_->dtype_id_ = element_type->type_id();
  } else if (element_type->object_type() == kObjectTypeString) {
    // String type element.
    address_common_->dtype_id_ = element_type->type_id();
  } else if (element_type->object_type() == kObjectTypeTuple) {
    // Sequence type element.
    auto tuple_type = element_type->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_type);
    if (tuple_type->dynamic_len()) {
      if (tuple_type->dynamic_element_type() == nullptr) {
        return;
      }
      SetSequenceDType(tuple_type->dynamic_element_type());
      return;
    }
    const TypePtrList &element_types = tuple_type->elements();
    if (element_types.empty() || element_types[0] == nullptr) {
      return;
    }
    SetSequenceDType(element_types[0]);
    return;
  } else if (element_type->object_type() == kObjectTypeList) {
    // Sequence type element.
    auto list_type = element_type->cast<ListPtr>();
    MS_EXCEPTION_IF_NULL(list_type);
    if (list_type->dynamic_len()) {
      if (list_type->dynamic_element_type() == nullptr) {
        return;
      }
      SetSequenceDType(list_type->dynamic_element_type());
      return;
    }
    const TypePtrList &element_types = list_type->elements();
    if (element_types.empty() || element_types[0] == nullptr) {
      return;
    }
    SetSequenceDType(element_types[0]);
    return;
  } else {
    MS_LOG(EXCEPTION) << "Unsupported element type [" << element_type->ToString()
                      << "] when setting the element data type for KernelTensor.";
  }
}

std::string KernelTensor::GetStringFormat() const { return GetFormatFromEnumToStr(address_common_->format_); }

void KernelTensor::SetStringFormat(const std::string &format) {
  address_common_->format_ = GetFormatFromStrToEnum(format);
}

ValuePtr KernelTensor::GetValue() const {
  MS_EXCEPTION_IF_NULL(host_info_);
  std::lock_guard<std::mutex> lock(host_info_->value_mutex_);

  // There may be an original value in the KernelTensor (e.g., one that came from a ValueNode).
  if (address_common_->dtype_id_ == kMetaTypeNone) {
    return kNone;
  } else if (value_ && !value_->isa<ValueAny>()) {
    if (host_info_->kernel_tensor_value_ == nullptr) {
      host_info_->kernel_tensor_value_ = ConvertValueToKernelTensorValue(value_);
      return host_info_->kernel_tensor_value_ ? host_info_->kernel_tensor_value_ : value_;
    }
    return host_info_->kernel_tensor_value_;
  }

  // Sync the value data from device.
  if (!SyncDataFromDeviceToHost()) {
    MS_LOG(EXCEPTION) << "Sync data from device to host side failed";
  }
  return host_info_->kernel_tensor_value_;
}

const void *KernelTensor::GetValuePtr() {
  CheckHostInfoValid();
  std::lock_guard<std::mutex> lock(host_info_->value_mutex_);

  // There may be an original value in the KernelTensor (e.g., one that came from a ValueNode).
  if (address_common_->dtype_id_ == kMetaTypeNone) {
    return nullptr;
  } else if (value_ && !value_->isa<ValueAny>()) {
    if (host_info_->kernel_tensor_value_ == nullptr) {
      host_info_->kernel_tensor_value_ = ConvertValueToKernelTensorValue(value_);
    }
    MS_EXCEPTION_IF_NULL(host_info_->kernel_tensor_value_);
    return host_info_->kernel_tensor_value_->GetDataPtr();
  }

  // Sync the value data from device.
  if (!SyncDataFromDeviceToHost()) {
    MS_LOG(EXCEPTION) << "Sync data from device to host side failed";
  }
  return host_info_->kernel_tensor_value_->GetDataPtr();
}

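// Copies the tensor's value from device memory into host_info_->kernel_tensor_value_. On the
// CPU backend the device pointer is host-accessible, so the value wraps it directly instead
// of copying.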
bool KernelTensor::SyncDataFromDeviceToHost() const {
  // Note: the lock must be released while waiting for async resize and kernel launch tasks to finish,
  // because those tasks may call this kernel tensor's GetValue and try to lock the same mutex; holding
  // the lock across the wait would deadlock.
  host_info_->value_mutex_.unlock();
  WaitAsyncResizeAndLaunchFinish();
  host_info_->value_mutex_.lock();

  void *device_ptr = this->device_ptr();
  if (device_ptr == nullptr) {
    MS_LOG(ERROR) << "Device memory is not allocated yet, sync data from device to host side failed, size: "
                  << address_common_->size_;
    return false;
  }

  MS_EXCEPTION_IF_NULL(host_info_);
  // For performance, the CPU backend does not copy from device to host; it directly uses the device
  // pointer held in the kernel tensor.
  if (address_common_->device_name_ == kCPUDevice) {
    if (!host_info_->kernel_tensor_value_) {
      host_info_->kernel_tensor_value_ = std::make_shared<KernelTensorValue>(device_ptr, address_common_->size_, type_);
    } else {
      host_info_->kernel_tensor_value_->SetDataPtr(device_ptr);
      host_info_->kernel_tensor_value_->Resize(address_common_->size_);
    }
    return true;
  }

  if (!host_info_->kernel_tensor_value_) {
    host_info_->kernel_tensor_value_ = std::make_shared<KernelTensorValue>(address_common_->size_, type_);
  } else {
    host_info_->kernel_tensor_value_->Resize(address_common_->size_);
  }

  if (address_common_->size_ == 0) {
    return true;
  }

  void *host_ptr = host_info_->kernel_tensor_value_->GetMutableDataPtr();
  MS_EXCEPTION_IF_NULL(host_ptr);

  MS_EXCEPTION_IF_NULL(device_synchronizer_);
  if (!device_synchronizer_->SyncDeviceToHost(
        host_ptr, device_ptr, address_common_->size_, address_common_->device_name_, address_common_->device_id_,
        address_common_->format_, address_common_->shape_vector_, address_common_->stream_id_, user_data_)) {
    MS_LOG(EXCEPTION) << "Sync data from device to host side failed";
  }
  return true;
}

bool KernelTensor::IsDynamicShape() const {
  const auto &shape = this->GetShapeVector();
  return std::any_of(shape.cbegin(), shape.cend(), [](auto i) { return i < 0; });
}

ShapeVector KernelTensor::GetMaxShape() const {
  MS_EXCEPTION_IF_NULL(host_info_);
  if (host_info_->type_id_ != kObjectTypeTensorType) {
    return {};
  }
  if (shape_ == nullptr || !shape_->isa<abstract::Shape>()) {
    return {};
  }

  return shape_->cast<abstract::ShapePtr>()->max_shape();
}

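// Recomputes output_size_list_ from the current output shapes. For example (illustrative
// values): a float32 output of shape {2, 3} yields an entry of 24 bytes; an output whose
// shape is still unknown gets the per-element size as a placeholder, and the function
// returns KRET_UNKNOWN_OUT_SHAPE so the caller can defer allocation.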
int KernelMod::Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) {
  auto ret = KRET_OK;
  workspace_size_list_.clear();
  output_size_list_.clear();

  for (size_t idx = 0; idx < outputs.size(); idx++) {
    auto &output = outputs[idx];
    size_t tensor_size = 0;
    MS_EXCEPTION_IF_NULL(output);
    size_t type_size = UnitSizeInBytes(output->dtype_id());
    if (type_size == 0) {
      MS_LOG(WARNING) << "The type size is 0, type: " << TypeIdToType(output->dtype_id())->ToString();
    }

    const auto &shape = output->GetShapeVector();
    if (!IsValidShape(shape)) {
      MS_LOG(WARNING) << "Invalid shape:" << mindspore::ToString(shape) << ", kernel name:" << kernel_name();
      // Note:
      // If the output shape is unknown, the op is a compute-depended op, and output_size_list_ is set to
      // the default size: type_size.
      tensor_size = type_size;
      ret = KRET_UNKNOWN_OUT_SHAPE;
    } else {
      if (shape.empty()) {
        tensor_size = type_size;
      } else {
        auto cur_out_shape_num = SizeOf(shape);
        tensor_size = cur_out_shape_num * type_size;
        if (type_size != 0 && tensor_size / type_size != cur_out_shape_num) {
          MS_EXCEPTION(ValueError) << "For " << kernel_name_ << ", the shape of outputs[" << output_size_list_.size()
                                   << "]: " << shape
                                   << " is too big, MindSpore cannot allocate such a large amount of memory.";
        }
      }
    }
    (void)output_size_list_.emplace_back(tensor_size);
  }
  return static_cast<int>(ret);
}

std::vector<std::vector<int64_t>> GetShapes(const std::vector<KernelTensor *> &tensors) {
  std::vector<std::vector<int64_t>> shapes(tensors.size());
  for (size_t idx = 0; idx < shapes.size(); idx++) {
    shapes[idx] = tensors[idx]->GetShapeVector();
  }
  return shapes;
}

void ConvertLaunchInfoToAddr(const KernelLaunchInfo &launch_info, KernelLaunchAddr *mem_info) {
  (mem_info->inputs_).clear();
  (mem_info->outputs_).clear();
  (mem_info->workspaces_).clear();
  std::transform((launch_info.inputs_).begin(), (launch_info.inputs_).end(), std::back_inserter(mem_info->inputs_),
                 [](const auto &input) { return std::make_shared<Address>(input->device_ptr(), input->size()); });
  std::transform(
    (launch_info.workspaces_).begin(), (launch_info.workspaces_).end(), std::back_inserter(mem_info->workspaces_),
    [](const auto &workspace) { return std::make_shared<Address>(workspace->device_ptr(), workspace->size()); });
  std::transform((launch_info.outputs_).begin(), (launch_info.outputs_).end(), std::back_inserter(mem_info->outputs_),
                 [](const auto &output) { return std::make_shared<Address>(output->device_ptr(), output->size()); });
}
}  // namespace kernel
}  // namespace mindspore