/**
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/pynative/op_runner.h"

#include <string>
#include <vector>
#include <memory>
#include <map>
#include <unordered_map>
#include <algorithm>
#include <array>
#include "ops/structure_op_name.h"
#include "utils/log_adapter.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/backend/optimizer/helper.h"
#include "include/backend/device_type.h"
#include "include/common/utils/convert_utils.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/device/device_address_utils.h"
#include "runtime/pynative/op_runtime_info.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/pynative/op_compiler.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "kernel/framework_utils.h"
#include "include/backend/mem_reuse/mem_tracker.h"
#ifndef ENABLE_SECURITY
#include "include/backend/debug/profiler/profiling.h"
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
#include "pybind_api/gil_scoped_long_running.h"
#include "runtime/pynative/ir_converter.h"

using mindspore::profiler::ProfilerManager;
#endif
using EdgePtr = mindspore::pynative::EdgePtr;

namespace mindspore::runtime {
namespace {
constexpr size_t kContextSize = 4;
std::unique_ptr<std::mutex> kDeviceContextMutex = std::make_unique<std::mutex>();
std::array<DeviceContext *, kContextSize> kDeviceContexts = {nullptr, nullptr, nullptr, nullptr};

// Drop the tensor's cached device address when it cannot be reused for this graph:
// 1. The device type differs in heterogeneous scenarios.
// 2. The device address format differs from the node's format.
void UpdateInputTensorFromDevice(const std::vector<AnfNodePtr> &input_nodes,
                                 const std::vector<tensor::BaseTensorPtr> &input_tensors,
                                 const device::DeviceContext *device_context) {
  MS_LOG(DEBUG) << "Start";
  auto input_size = input_nodes.size();
  for (size_t i = 0; i < input_size; ++i) {
    auto &tensor = input_tensors[i];
    auto &input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    auto node_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
    // node_address can't be null
    MS_EXCEPTION_IF_NULL(node_address);
    MS_EXCEPTION_IF_NULL(device_context);
    if (tensor_address != nullptr) {
      if (tensor_address->GetDeviceType() != device_context->GetDeviceType() ||
          tensor_address->format() != node_address->format()) {
        // Need to wait for the OpExecutor task to finish.
        tensor->data_sync();
        // Once the tensor address is null, the Parameter address will be set to the Tensor later.
        tensor->set_device_address(nullptr);
      }
    }
  }
  MS_LOG(DEBUG) << "End";
}

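// Sync a dynamic-shape Parameter's inferred shape with the actual shape of the input tensor.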
void UpdateParameterShapeFromInputTensor(const AnfNodePtr &input_node, const tensor::BaseTensorPtr &input_tensor) {
  MS_EXCEPTION_IF_NULL(input_node);
  if (input_tensor == nullptr || !input_node->isa<Parameter>()) {
    return;
  }

  auto input_param = input_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(input_param);
  if (!input_param->has_dynamic_shape()) {
    return;
  }

  auto shape = input_tensor->shape();
  MS_LOG(DEBUG) << "Update input node shape to:" << shape;
  common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape},
                                              input_node.get());
}

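// Bind a device address to both the input tensor and the input node: a tensor without an address
// adopts the node's address; otherwise the tensor's address (made contiguous if strided) is set on the node.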
void SetDeviceAddress(const AnfNodePtr &input_node, const tensor::BaseTensorPtr &input_tensor,
                      const device::DeviceContext *device_context, bool is_sync) {
  MS_EXCEPTION_IF_NULL(input_tensor);
  auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
  auto node_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);

  UpdateParameterShapeFromInputTensor(input_node, input_tensor);

  MS_EXCEPTION_IF_NULL(node_address);
  if (tensor_address == nullptr) {
    input_tensor->set_device_address(node_address);
    input_tensor->set_sync_status(kNeedSyncHostToDeviceImmediately);
    input_tensor->set_need_pipeline_sync(true);
    node_address->set_from_persistent_mem(input_tensor->is_parameter());
    node_address->SetNodeIndex(input_node, 0);
  }

  // The DeviceType and format of the DeviceAddress are always the same after UpdateInputTensorFromDevice.
  if (tensor_address != nullptr && tensor_address != node_address) {
    auto address = tensor_address;
    if (tensor_address->GetTensorStorageInfo() != nullptr) {
      address = DeviceAddressUtils::ConvertContiguousDeviceAddress(device_context, tensor_address, is_sync);
      input_tensor->set_device_address(address);
    }
    AnfAlgo::SetOutputAddr(address, 0, input_node.get());
  }
}

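// Assign device addresses for all graph inputs; a MapTensor input also binds its key, value and status tensors.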
void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes,
                                  const std::vector<tensor::BaseTensorPtr> &input_tensors,
                                  const device::DeviceContext *device_context, bool is_sync) {
  MS_LOG(DEBUG) << "Start";
  auto input_size = input_nodes.size();
  auto tensor_size = input_tensors.size();
  if (input_size != tensor_size) {
    MS_LOG(EXCEPTION) << "Input node size:" << input_size << " is not equal to tensor size:" << tensor_size;
  }
  for (size_t i = 0; i < input_size; ++i) {
    auto &input_node = input_nodes[i];
    auto &input_tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(input_tensor);
    if (input_tensor->isa<tensor::MapTensor>()) {
      auto map_tensor = input_tensor->cast<tensor::MapTensorPtr>();
      MS_EXCEPTION_IF_NULL(map_tensor);
      SetDeviceAddress(input_node, map_tensor, device_context, is_sync);
      SetDeviceAddress(input_node, map_tensor->key_tensor(), device_context, is_sync);
      SetDeviceAddress(input_node, map_tensor->value_tensor(), device_context, is_sync);
      SetDeviceAddress(input_node, map_tensor->status_tensor(), device_context, is_sync);
    } else {
      SetDeviceAddress(input_node, input_tensor, device_context, is_sync);
    }
  }
  MS_LOG(DEBUG) << "End";
}

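// Allocate device memory for the tensor's address (unless flagged to ignore the device pointer)
// and copy the host data to it.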
void CopyTensorDataToDevice(const tensor::BaseTensorPtr &tensor, const AnfNodePtr &node,
                            const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_EXCEPTION_IF_CHECK_FAIL(device_address != nullptr, "Tensor device address is nullptr, id is " + tensor->id());
  // Skip copying data to the device if the device address carries the ignore flag.
  if (TEST_FLAG(device_address->flag(), device::kDeviceAddressFlagIgnoreDevicePtr)) {
    MS_LOG(DEBUG) << "Node " << node->DebugString() << " with address " << device_address
                  << " has flag ignore device address, so skip copy tensor to device";
    return;
  }

  auto mem_type = tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kPyNativeInput;
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", mem_type, device_address->GetSize(),
                                                 device_address.get());
  if ((device_address->GetPtr() == nullptr) &&
      (!device_context->device_res_manager_->AllocateMemory(device_address.get()))) {
    MS_LOG(EXCEPTION) << "Allocate memory failed, alloc size " << device_address->GetSize() << "B";
  }
  // Copy data from the host tensor to the device.
  auto tensor_size = LongToSize(tensor->data().nbytes());
  auto tensor_type = tensor->data_type();
  MS_LOG(DEBUG) << "Copy to device, node:" << common::AnfAlgo::GetNodeDebugString(node);
  if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), tensor_size, tensor_type,
                                        "DefaultFormat", tensor->data_ptr())) {
    MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
  }
}

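// Copy the graph's value nodes (constants) to device memory if they are not resident yet.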
void CopyValueNodeDataToDevice(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(DEBUG) << "Start";
  const auto &value_nodes = graph->graph_value_nodes();
  for (const auto &value_node : value_nodes) {
    MS_EXCEPTION_IF_NULL(value_node);
    const auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    if (!node_value->isa<tensor::BaseTensor>() && !node_value->isa<ValueTuple>() && !node_value->isa<Scalar>() &&
        !node_value->isa<StringImm>()) {
      MS_LOG(INFO) << "Unknown value node type:" << value_node->DebugString();
      continue;
    }

    const auto &node_address = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
    MS_EXCEPTION_IF_NULL(node_address);
    node_address->SetNodeIndex(value_node, 0);
    if (node_address->GetPtr() != nullptr) {
      continue;
    }
    auto shape = trans::GetRuntimePaddingShape(value_node, 0);
    runtime::DeviceAddressUtils::CopyNoneTensorDataToDevice(device_context, node_address, shape);
  }
  MS_LOG(DEBUG) << "End";
}

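// For a tensor compiled with a dynamic shape, shrink or grow the device address to the actual data size.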
void UpdateAddressSizeForDynamicShapeTensor(const tensor::BaseTensorPtr &input_tensor) {
  MS_EXCEPTION_IF_NULL(input_tensor);
  if (input_tensor->base_shape_ptr() != nullptr) {
    auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
    MS_EXCEPTION_IF_NULL(device_address);
    auto tensor_size = LongToSize(input_tensor->data().nbytes());
    if (tensor_size != device_address->GetSize()) {
      device_address->SetSize(tensor_size);
    }
  }
}

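// Copy a MapTensor's key, value and status tensors to the device, resizing their addresses first.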
void CopyMapTensorDataToDevice(const tensor::MapTensorPtr &map_tensor, const AnfNodePtr &input_node,
                               const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(map_tensor);
  auto key_tensor = map_tensor->key_tensor();
  MS_EXCEPTION_IF_NULL(key_tensor);
  UpdateAddressSizeForDynamicShapeTensor(key_tensor);
  CopyTensorDataToDevice(key_tensor, input_node, device_context);
  key_tensor->set_sync_status(kNoNeedSync);
  auto value_tensor = map_tensor->value_tensor();
  MS_EXCEPTION_IF_NULL(value_tensor);
  UpdateAddressSizeForDynamicShapeTensor(value_tensor);
  CopyTensorDataToDevice(value_tensor, input_node, device_context);
  value_tensor->set_sync_status(kNoNeedSync);
  auto status_tensor = map_tensor->status_tensor();
  MS_EXCEPTION_IF_NULL(status_tensor);
  UpdateAddressSizeForDynamicShapeTensor(status_tensor);
  CopyTensorDataToDevice(status_tensor, input_node, device_context);
  status_tensor->set_sync_status(kNoNeedSync);
}

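// Copy parameter data to the device for inputs that still require an immediate host-to-device sync.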
void CopyParameterDataToDevice(const std::vector<AnfNodePtr> &input_nodes,
                               const std::vector<tensor::BaseTensorPtr> &input_tensors,
                               const device::DeviceContext *device_context) {
  MS_LOG(DEBUG) << "Start";
  auto input_size = input_nodes.size();
  if (input_size > input_tensors.size()) {
    MS_LOG(EXCEPTION) << "input_size is bigger than input_tensors size, input_size:" << input_size
                      << ", input_tensors size:" << input_tensors.size();
  }
  for (size_t i = 0; i < input_size; ++i) {
    MS_EXCEPTION_IF_NULL(input_tensors[i]);
    if (input_tensors[i]->NeedSyncHostToDeviceImmediately()) {
      // First op in a dynamic shape scenario (feed mode).
      if (input_tensors[i]->isa<tensor::MapTensor>()) {
        auto map_tensor = input_tensors[i]->cast<tensor::MapTensorPtr>();
        CopyMapTensorDataToDevice(map_tensor, input_nodes[i], device_context);
      } else {
        UpdateAddressSizeForDynamicShapeTensor(input_tensors[i]);
        CopyTensorDataToDevice(input_tensors[i], input_nodes[i], device_context);
        input_tensors[i]->set_sync_status(kNoNeedSync);
      }
    }
  }
  MS_LOG(DEBUG) << "End";
}

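// Allocate device memory for kernel inputs that have no pointer yet; None inputs and addresses
// flagged to ignore the device pointer are skipped.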
bool MallocForKernelInput(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
                          const device::DeviceContext *device_context, const CNodePtr &node) {
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(runtime_info);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  auto input_size = runtime_info->GetInputSize();
  for (size_t i = 0; i < input_size; ++i) {
    if (common::AnfAlgo::IsNoneInput(node, i)) {
      MS_EXCEPTION_IF_NULL(node);
      MS_LOG(DEBUG) << "Input [" << i << "] of " << node->fullname_with_scope() << " is None, no need to allocate.";
      continue;
    }
    auto input_address = runtime_info->GetInputDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    MS_EXCEPTION_IF_NULL(input_address);
    if (TEST_FLAG(input_address->flag(), device::kDeviceAddressFlagIgnoreDevicePtr)) {
      MS_LOG(DEBUG) << "Node " << node->DebugString() << " input[" << i << "] with address " << input_address
                    << " has flag ignore device address, so skip malloc device address";
      continue;
    }
    if (input_address->GetPtr() == nullptr) {
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kPyNativeOutput,
                                                     input_address->GetSize(), input_address.get());
      if (!device_context->device_res_manager_->AllocateMemory(input_address.get())) {
        MS_LOG(EXCEPTION) << "Allocate memory failed, alloc size " << input_address->GetSize() << "B";
      }
    }
  }
  return true;
}

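// Allocate device memory for kernel outputs, first reconciling each address size with the size
// reported by the kernel mod (some kernels only know their real output size at runtime).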
bool MallocForKernelOutput(const std::shared_ptr<OpRuntimeInfo> &runtime_info, const AnfNodePtr &node,
                           const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(runtime_info);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);

  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_size = runtime_info->GetOutputSize();
  auto kernel_out_size_list = kernel_mod->GetOutputSizeList();
  if (kernel_out_size_list.size() != output_size) {
    MS_LOG(ERROR) << "Node " << node->fullname_with_scope() << " output num is:" << output_size
                  << " but kernel_mod output num:" << kernel_out_size_list.size();
    return false;
  }
  for (size_t i = 0; i < output_size; ++i) {
    auto device_address = runtime_info->GetOutputDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    // For example, we need to call cudnnGetRNNTrainingReserveSize to get the real output size in LstmGpuKernelMod!
    if (kernel_out_size_list[i] != device_address->GetSize() &&
        AnfAlgo::GetOutputFormat(node, i) == device_address->format()) {
      // If the formats of the DeviceAddress differ, the sizes legitimately differ as well,
      // e.g. NCHW(1,1,1,3) vs NC1HWC0(1,1,1,1,16), so the size is not updated in that case.
      if (device_address->GetPtr() != nullptr) {
        MS_LOG(ERROR) << "kernel mod output " << i << " size:" << kernel_out_size_list[i]
                      << " not equal to device_address size:" << device_address->GetSize()
                      << ", but the device address already has a ptr";
        return false;
      }
      device_address->SetSize(kernel_out_size_list[i]);
    }
    if (device_address->GetPtr() == nullptr) {
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kPyNativeOutput,
                                                     device_address->GetSize(), device_address.get());
      if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
        MS_LOG(EXCEPTION) << "Allocate output memory failed, alloc node:" << node->fullname_with_scope()
                          << " alloc size:" << device_address->GetSize() << "B";
      }
    }
  }
  return true;
}

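// Collect the raw KernelTensor pointers of the kernel's input device addresses.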
std::vector<kernel::KernelTensor *> GetInputKernelTensors(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
                                                          const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(runtime_info);
  auto input_size = runtime_info->GetInputSize();
  std::vector<kernel::KernelTensor *> inputs;
  for (size_t i = 0; i < input_size; ++i) {
    auto device_address = runtime_info->GetInputDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    (void)inputs.emplace_back(device_address->kernel_tensor().get());
    MS_EXCEPTION_IF_NULL(inputs.back());
    MS_LOG(DEBUG) << "input[" << i << "]:" << inputs.back()->device_ptr() << " size:" << inputs.back()->size();
  }
  return inputs;
}

std::vector<abstract::AbstractBasePtr> GetInputKernelTensorsForInfer(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
                                                                     const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(runtime_info);
  auto input_size = runtime_info->GetInputSize();
  std::vector<abstract::AbstractBasePtr> inputs;
  for (size_t i = 0; i < input_size; ++i) {
    auto device_address = runtime_info->GetInputDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    (void)inputs.emplace_back(device_address->kernel_tensor());
    MS_EXCEPTION_IF_NULL(inputs.back());
  }
  return inputs;
}

std::vector<kernel::KernelTensor *> GetWorkspaceKernelTensors(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
                                                              const device::DeviceContext *device_context,
                                                              size_t workspace_size, size_t workspace_sizes) {
  std::vector<kernel::KernelTensor *> workspaces;
  for (size_t i = 0; i < workspace_size && i < workspace_sizes; ++i) {
    auto device_address = runtime_info->GetWorkspaceDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kWorkSpace,
                                                   device_address->GetSize(), device_address.get());
    if (device_address->GetPtr() == nullptr &&
        !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate workspace memory failed, alloc size:" << device_address->GetSize() << "B";
    }
    (void)workspaces.emplace_back(device_address->kernel_tensor().get());
    MS_EXCEPTION_IF_NULL(workspaces.back());
    MS_LOG(DEBUG) << "workspace[" << i << "]:" << workspaces.back()->device_ptr()
                  << " size:" << workspaces.back()->size();
  }
  return workspaces;
}

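// Dynamic overload: create extra workspace addresses when the resized kernel needs more workspaces
// than were compiled, refresh the sizes of existing ones, then allocate them all.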
std::vector<kernel::KernelTensor *> GetWorkspaceKernelTensors(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
                                                              const device::DeviceContext *device_context,
                                                              const CNodePtr &kernel, bool is_dynamic_shape,
                                                              bool is_dynamic_value) {
  MS_EXCEPTION_IF_NULL(runtime_info);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  auto workspace_size = runtime_info->GetWorkspaceSize();
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();

  std::vector<device::DeviceAddressPtr> add_workspaces;
  if (is_dynamic_shape || is_dynamic_value) {
    // Resize may demand more workspaces than were compiled, because the workspace size is dynamic.
    if (workspace_size < workspace_sizes.size()) {
      for (size_t i = workspace_size; i < workspace_sizes.size(); ++i) {
        auto kernel_tensor = std::make_shared<KernelTensor>(
          nullptr, workspace_sizes[i], Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
          device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
        auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
        MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
                      << " addr:" << device_address;
        AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());  // set to kernel_info
        MS_EXCEPTION_IF_NULL(device_address);
        (void)add_workspaces.emplace_back(device_address);
      }
    }
  }

  // Set the new sizes on the existing workspace addresses.
  for (size_t i = 0; i < workspace_size && i < workspace_sizes.size(); ++i) {
    auto device_address = runtime_info->GetWorkspaceDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    device_address->SetSize(workspace_sizes[i]);
  }

  std::vector<kernel::KernelTensor *> workspaces =
    GetWorkspaceKernelTensors(runtime_info, device_context, workspace_size, workspace_sizes.size());
  for (size_t i = workspace_size; i < workspace_sizes.size(); ++i) {
    // add_workspaces is indexed from 0, so offset by the number of pre-existing workspaces.
    auto device_address = add_workspaces[i - workspace_size];
    MS_EXCEPTION_IF_NULL(device_address);
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kWorkSpace,
                                                   device_address->GetSize(), device_address.get());
    if (device_address->GetPtr() == nullptr &&
        !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate workspace memory failed, alloc size:" << device_address->GetSize() << "B";
    }
    (void)workspaces.emplace_back(device_address->kernel_tensor().get());
    MS_LOG(DEBUG) << "workspace[" << i << "]:" << workspaces.back()->device_ptr()
                  << " size:" << workspaces.back()->size();
  }
  return workspaces;
}

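// Fully dynamic path: build fresh workspace addresses from the kernel mod's workspace size list and
// allocate them, returning ownership through workspace_device_address.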
std::vector<kernel::KernelTensor *> GetWorkspaceKernelTensorsDynamic(
  const device::DeviceContext *device_context, const CNodePtr &kernel,
  std::vector<device::DeviceAddressPtr> *workspace_device_address) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();

  std::vector<kernel::KernelTensor *> workspaces;
  workspaces.reserve(workspace_sizes.size());
  for (size_t i = 0; i < workspace_sizes.size(); ++i) {
    auto kernel_tensor = std::make_shared<KernelTensor>(
      nullptr, workspace_sizes[i], Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
    auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    MS_EXCEPTION_IF_NULL(device_address);
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kWorkSpace,
                                                   device_address->GetSize(), device_address.get());
    if (device_address->GetPtr() == nullptr &&
        !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate dynamic workspace memory failed, alloc size:" << device_address->GetSize() << "B";
    }
    MS_EXCEPTION_IF_NULL(workspace_device_address);
    (void)workspace_device_address->emplace_back(device_address);
    (void)workspaces.emplace_back(device_address->kernel_tensor().get());
    MS_LOG(DEBUG) << "workspace[" << i << "]:" << workspaces.back()->device_ptr()
                  << " size:" << workspaces.back()->size();
  }
  return workspaces;
}

std::vector<kernel::KernelTensor *> GetOutputKernelTensors(const std::shared_ptr<OpRuntimeInfo> &runtime_info) {
  MS_EXCEPTION_IF_NULL(runtime_info);
  auto output_size = runtime_info->GetOutputSize();
  std::vector<kernel::KernelTensor *> outputs;
  for (size_t i = 0; i < output_size; ++i) {
    auto device_address = runtime_info->GetOutputDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    (void)outputs.emplace_back(device_address->kernel_tensor().get());
    MS_LOG(DEBUG) << "output[" << i << "]:" << outputs.back()->device_ptr() << " size:" << outputs.back()->size();
  }
  return outputs;
}

// Copy host-side data (value nodes and parameters) to the device.
void CopyDataToDevice(const KernelGraphPtr &graph, const std::vector<tensor::BaseTensorPtr> &input_tensors,
                      const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  CopyValueNodeDataToDevice(graph, device_context);
  CopyParameterDataToDevice(graph->input_nodes(), input_tensors, device_context);
}

BaseShapePtr InferNodeRealShape(const CNodePtr &kernel, const std::vector<abstract::AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(kernel);
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelInfer,
                                     kernel->fullname_with_scope(), false);
  auto *kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  return opt::dynamic_shape::InferShape(kernel_mod->primitive(), input_args);
}

void ResizeKernelMod(const CNodePtr &kernel, const std::vector<kernel::KernelTensor *> &inputs,
                     const std::vector<kernel::KernelTensor *> &outputs) {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelResize,
                                     kernel->fullname_with_scope(), false);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  kernel_mod->set_use_kernel_tensor(true);

  int ret = kernel_mod->Resize(inputs, outputs);
  if (ret != kernel::KRET_OK) {
    MS_LOG(EXCEPTION) << "Resize failed for kernel: " << kernel->fullname_with_scope();
  }
}

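// Mark which graph outputs are gradient outputs so that their memory is later taken from persistent memory.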
void SetOutputDeviceAddressFlag(const pynative::OpCompilerInfoPtr &op_compiler_info,
                                const session::BackendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &simple_graph = op_compiler_info->simple_graph_;
  size_t output_size = simple_graph->outputs_.size();
  // Reset grad output flag.
  const auto &outputs = simple_graph->outputs_;
  for (const auto &output : outputs) {
    output->is_grad_ = false;
  }

  if (op_run_info->is_gradient_out) {
    const auto &output_indexes = op_run_info->base_op_run_info.output_indexes;
    for (auto index : output_indexes) {
      if (index >= output_size) {
        MS_LOG(EXCEPTION) << "Gradient output index " << index << " >= graph output size " << output_size;
      }
      const auto &output = outputs[index];
      MS_EXCEPTION_IF_NULL(output);
      output->is_grad_ = true;
      MS_LOG(DEBUG) << "Set grad flag for op " << op_run_info->base_op_run_info.op_name << " index " << index;
    }
  }
}

void MallocForConstValue(const pynative::OpCompilerInfoPtr &op_compiler_info) {
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  const auto &device_context = op_compiler_info->device_context_;
  const auto &graph = op_compiler_info->graph_;
  CopyValueNodeDataToDevice(graph, device_context);
}

void UpdateOutputShape(const std::vector<EdgePtr> &output_edges) {
  for (const auto &edge : output_edges) {
    MS_EXCEPTION_IF_NULL(edge);
    const auto &device_address = edge->address_;
    MS_EXCEPTION_IF_NULL(device_address);
    const auto &kernel_tensor = device_address->kernel_tensor();
    MS_EXCEPTION_IF_NULL(kernel_tensor);
    device_address->set_host_shape(kernel_tensor->host_info_exist() ? kernel_tensor->GetShapeVector()
                                                                    : kernel_tensor->host_shape());
  }
}

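// Execute every kernel of the compiled graph in order: allocate inputs, re-infer and resize for
// dynamic shapes/values, allocate workspaces and outputs, then launch on the op's stream.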
void LaunchKernels(const KernelGraphPtr &graph, const device::DeviceContext *device_context,
                   const session::BackendOpRunInfoPtr &op_run_info,
                   const std::vector<tensor::BaseTensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_LOG(DEBUG) << "Start";

  // Get device address from OpRuntimeInfo
  const auto &execution_order = graph->execution_order();
  for (auto const &node : execution_order) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start launch kernel " << node->fullname_with_scope() << " kernel type "
                  << AnfAlgo::GetKernelType(node);
    auto is_dynamic_shape = common::AnfAlgo::IsDynamicShape(node);
    bool is_dynamic_value = common::AnfAlgo::IsDynamicValue(node);
    auto runtime_info = node->user_data<runtime::OpRuntimeInfo>();
    MS_EXCEPTION_IF_NULL(runtime_info);

    if (!MallocForKernelInput(runtime_info, device_context, node)) {
      MS_LOG(EXCEPTION) << "Malloc for kernel input failed, memory isn't enough, node:" << node->fullname_with_scope();
    }

    auto inputs = GetInputKernelTensors(runtime_info, node);
    auto outputs = GetOutputKernelTensors(runtime_info);
    if (is_dynamic_shape) {
      auto input_kernel_tensors_for_infer = GetInputKernelTensorsForInfer(runtime_info, node);
      auto out_shape = InferNodeRealShape(node, input_kernel_tensors_for_infer);
      opt::dynamic_shape::UpdateKernelTensorShape(out_shape, outputs);
      ResizeKernelMod(node, inputs, outputs);
    } else if (is_dynamic_value) {
      auto kernel_mod = runtime_info->GetKernelMod();
      MS_EXCEPTION_IF_NULL(kernel_mod);
      if (kernel_mod->Resize(inputs, outputs) != static_cast<int>(kernel::KRET_OK)) {
        MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " resize failed";
      }
    }
    auto workspaces = GetWorkspaceKernelTensors(runtime_info, device_context, node, is_dynamic_shape, is_dynamic_value);

    if (!MallocForKernelOutput(runtime_info, node, device_context)) {
      MS_LOG(EXCEPTION) << "Malloc for kernel output failed, memory isn't enough, node:" << node->fullname_with_scope();
    }

    MS_EXCEPTION_IF_NULL(device_context);
    MS_EXCEPTION_IF_NULL(device_context->GetKernelExecutor(true));
    auto kernel_mod = AnfAlgo::GetKernelMod(node);
    const size_t stream_id = op_run_info->base_op_run_info.stream_id;
    auto stream = device_context->device_res_manager_->GetStream(stream_id);
    if (!device_context->GetKernelExecutor(false)->LaunchKernel(node, inputs, workspaces, outputs, kernel_mod,
                                                                stream)) {
      MS_LOG(EXCEPTION) << "Launch kernel failed, name:" << node->fullname_with_scope();
    }
    runtime::DeviceAddressUtils::ProcessCrossStreamAddress(op_run_info->base_op_run_info.op_name, device_context,
                                                           stream_id, inputs, outputs);
  }
  MS_LOG(DEBUG) << "End";
}

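// Allocate device memory for graph outputs; gradient outputs are placed in persistent memory.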
void AllocateOutputMemory(const std::vector<EdgePtr> &output_edges, const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_context);
  for (const auto &edge : output_edges) {
    MS_EXCEPTION_IF_NULL(edge);
    const auto &device_address = edge->address_;
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetPtr() == nullptr) {
      if (edge->is_grad_) {
        device_address->set_from_persistent_mem(true);
      }
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kPyNativeOutput,
                                                     device_address->GetSize(), device_address.get());
      MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
      if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
        MS_LOG(EXCEPTION) << "Allocate device memory failed, alloc size:" << device_address->GetSize() << "B";
      }
    }
  }
}

void UpdateOutputDeviceInfo(const std::vector<EdgePtr> &edges, const CNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_size_list = kernel_mod->GetOutputSizeList();
  if (edges.size() != output_size_list.size()) {
    MS_LOG(EXCEPTION) << "Output device address's size " << edges.size() << " is not equal to output_size_list's size "
                      << output_size_list.size();
  }

  auto output_num = edges.size();
  for (size_t i = 0; i < output_num; ++i) {
    const auto &edge = edges[i];
    MS_EXCEPTION_IF_NULL(edge);
    const auto &device_address = edge->address_;
    MS_EXCEPTION_IF_NULL(device_address);
    const auto &kernel_tensor = device_address->kernel_tensor();
    MS_EXCEPTION_IF_NULL(kernel_tensor);
    device_address->set_host_shape(kernel_tensor->GetShapeVector());
    device_address->SetSize(output_size_list[i]);
  }
}

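// Heterogeneous check against the cached address: if the device type or format differs, sync the
// tensor back to host and drop its device address so that it is re-assigned.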
void UpdateInputTensorForHeterogeneous(const DeviceContext *device_context, const tensor::BaseTensorPtr &input_tensor,
                                       const device::DeviceAddressPtr &cached_device_address) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(cached_device_address);
  MS_EXCEPTION_IF_NULL(input_tensor);
  auto device_sync = input_tensor->device_address();
  if (device_sync == nullptr) {
    return;
  }
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->GetDeviceType() != device_context->GetDeviceType() ||
      device_address->format() != cached_device_address->format()) {
    // Need to wait for the OpExecutor task to finish.
    input_tensor->data_sync();
    // Once the tensor address is null, the Parameter address will be set to the Tensor later.
    input_tensor->set_device_address(nullptr);
  }
}

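// Build a new device address for the edge from the input tensor's shape, reusing the origin
// address's format and dtype; used when the tensor carries no device address yet.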
void UpdateAddressInfoByInputTensor(const OpCompilerInfoPtr &op_compiler_info, const tensor::BaseTensorPtr &tensor,
                                    const EdgePtr &edge, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(node);
  auto &device_context = op_compiler_info->device_context_;
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);

  auto origin_address = edge->origin_address_;

  const auto &format = origin_address->format();
  const auto dtype = origin_address->type_id();
  const auto &shape = tensor->shape();
  size_t tensor_size = DeviceAddressUtils::GetTensorDeviceSize(device_context, node, shape, format, dtype, 0);

  const auto &kernel_tensor = origin_address->kernel_tensor();
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  auto new_kernel_tensor = kernel_tensor->CloneKernelTensor();
  MS_EXCEPTION_IF_NULL(new_kernel_tensor);

  new_kernel_tensor->SetShapeVector(shape);
  new_kernel_tensor->set_device_ptr(nullptr);
  auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(new_kernel_tensor);
  MS_EXCEPTION_IF_NULL(new_device_address);
  new_device_address->set_host_shape(shape);
  new_device_address->SetSize(tensor_size);
  new_device_address->set_from_persistent_mem(tensor->is_parameter());
  edge->address_ = new_device_address;
}

std::vector<kernel::KernelTensor *> GetInputKernelTensors(const std::vector<EdgePtr> &edges) {
  std::vector<kernel::KernelTensor *> input_kernel_tensors;
  input_kernel_tensors.reserve(edges.size());
  (void)std::transform(edges.begin(), edges.end(), std::back_inserter(input_kernel_tensors), [](const EdgePtr &edge) {
    MS_EXCEPTION_IF_NULL(edge->address_);
    return edge->address_->kernel_tensor().get();
  });
  return input_kernel_tensors;
}

std::vector<abstract::AbstractBasePtr> GetInputInferAbstract(const std::vector<EdgePtr> &edges) {
  std::vector<abstract::AbstractBasePtr> input_abstracts;
  input_abstracts.reserve(edges.size());
  (void)std::transform(edges.begin(), edges.end(), std::back_inserter(input_abstracts), [](const EdgePtr &edge) {
    MS_EXCEPTION_IF_NULL(edge->address_);
    return edge->address_->kernel_tensor();
  });
  return input_abstracts;
}

std::vector<kernel::KernelTensor *> GetOutputKernelTensors(const std::vector<EdgePtr> &edges,
                                                           const DeviceContext *device_context) {
  std::vector<kernel::KernelTensor *> output_kernel_tensors;
  output_kernel_tensors.reserve(edges.size());
  for (const auto &edge : edges) {
    // For example, output is dynamic or the output is between two ops.
    if (edge->address_ == nullptr) {
      edge->address_ = runtime::DeviceAddressUtils::CloneEmptyDeviceAddress(edge->origin_address_, device_context);
    }
    const auto &output_address = edge->address_;
    MS_EXCEPTION_IF_NULL(output_address);
    output_kernel_tensors.push_back(output_address->kernel_tensor().get());
  }
  return output_kernel_tensors;
}
}  // namespace

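// Collect the op's non-constant inputs as tensors, in order, skipping value-node inputs.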
std::vector<tensor::BaseTensorPtr> OpRunner::GetTensorWithoutValueMask(
  const session::BackendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  std::vector<tensor::BaseTensorPtr> tensors_without_value_node;
  const auto &input_values = op_run_info->base_op_run_info.expanded_input_values;
  const auto &input_masks = op_run_info->base_op_run_info.input_types;
  if (input_values.size() != input_masks.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_values.size() << " should be equal to tensors mask size "
                      << input_masks.size();
  }
  for (size_t index = 0; index < input_masks.size(); ++index) {
    runtime::DeviceAddressUtils::CreateKernelTensor(input_values[index]);
    if (input_masks.at(index) != InputType::kConstant) {
      if (!input_values[index]->isa<tensor::BaseTensor>()) {
        MS_LOG(EXCEPTION) << "The input at index " << index << " should be a Tensor, but got "
                          << input_values[index]->ToString();
      }
      (void)tensors_without_value_node.emplace_back(input_values.at(index)->cast<tensor::BaseTensorPtr>());
    }
  }
  return tensors_without_value_node;
}

// Determine the addresses of the graph once; they are not changed in subsequent executions.
void OpRunner::UpdateDeviceAddress(const KernelGraphPtr &graph,
                                   const std::vector<tensor::BaseTensorPtr> &tensors_without_value_mask,
                                   const device::DeviceContext *device_context, bool is_sync) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(DEBUG) << "Start";
  const auto &input_nodes = graph->input_nodes();
  UpdateInputTensorFromDevice(input_nodes, tensors_without_value_mask, device_context);
  UpdateInputNodeDeviceAddress(input_nodes, tensors_without_value_mask, device_context, is_sync);
  pynative::OpCompiler::UpdateRefNodeOutputDeviceAddress(graph);
  MS_LOG(DEBUG) << "End";
}

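// Static-shape path: copy constants and parameters to the device, then launch all kernels.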
void OpRunner::RunSingleOpGraph(const session::BackendOpRunInfoPtr &op_run_info,
                                const OpCompilerInfoPtr &op_compiler_info,
                                const std::vector<tensor::BaseTensorPtr> &input_tensors) {
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "PyNative", op_run_info->base_op_run_info.op_name,
                                                 op_compiler_info->graph_->ToString());
  CopyDataToDevice(op_compiler_info->graph_, input_tensors, op_compiler_info->device_context_);
  LaunchKernels(op_compiler_info->graph_, op_compiler_info->device_context_, op_run_info, input_tensors);
}

void OpRunner::LaunchKernelTask(const runtime::KernelTaskType &task_type, DeviceContext *device_context,
                                const device::DeviceAddressPtrList &input_addr_list,
                                const device::DeviceAddressPtrList &output_addr_list, size_t stream_id) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_LOG(DEBUG) << "Start, task_type:" << task_type;
  if (!device_context->GetKernelExecutor(false)->ExecuteKernelTask(task_type, input_addr_list, output_addr_list,
                                                                   stream_id)) {
    MS_LOG(EXCEPTION) << "ExecuteKernelTask failed, task_type:" << task_type;
  }
  MS_LOG(DEBUG) << "End";
}

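// Look up (and lazily create) the DeviceContext for a device type. Contexts are cached per device
// type; creation releases the GIL and is serialized by kDeviceContextMutex.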
DeviceContext *OpRunner::GetDeviceContext(const std::string &device_type) {
  auto type_iter = device::device_name_to_type_map.find(device_type);
  if (type_iter == device::device_name_to_type_map.end()) {
    MS_LOG(EXCEPTION) << "Invalid device_type " << device_type;
  }

  auto index = static_cast<size_t>(type_iter->second);
  auto cached_device_context = kDeviceContexts[index];

  if (cached_device_context != nullptr) {
    return cached_device_context;
  }

  GilReleaseWithCheck release_gil;
  std::unique_lock<std::mutex> lock(*kDeviceContextMutex);

  auto device_id = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_type, device_id});
  MS_EXCEPTION_IF_NULL(device_context);
  device_context->Initialize();

  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  (void)device_context->device_res_manager_->BindDeviceToCurrentThread(false);
  kDeviceContexts[index] = device_context;
  MS_LOG(DEBUG) << "Get device context of " << device_type << " id " << device_id;
  return device_context;
}

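// Reset the context cache and recreate the mutex in the child process after fork.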
void OpRunner::ChildAfterFork() {
  kDeviceContexts.fill(nullptr);
  kDeviceContextMutex = std::make_unique<std::mutex>();
}

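// Dynamic-shape path: per kernel, infer the real output shape when needed, resize the kernel mod,
// allocate workspaces and outputs, launch, and propagate runtime-updated output shapes.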
void DynamicOpRunner::RunSingleOpGraph(const session::BackendOpRunInfoPtr &op_run_info,
                                       const OpCompilerInfoPtr &op_compiler_info,
                                       const std::vector<tensor::BaseTensorPtr> &input_tensors) {
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "PyNative", op_run_info->base_op_run_info.op_name,
                                                 op_compiler_info->graph_->ToString());
  DynamicOpRunner::CopyHostToDevice(op_compiler_info, input_tensors);
  MallocForConstValue(op_compiler_info);

  const auto &simple_graph = op_compiler_info->simple_graph_;
  const auto &single_ops = simple_graph->single_ops_;
  bool is_need_infer = false;
  auto op_num = single_ops.size();
  MS_EXCEPTION_IF_NULL(op_run_info->base_op_run_info.abstract);
  if (op_num > 1 || op_run_info->base_op_run_info.abstract->BuildShape()->IsDynamic()) {
    is_need_infer = true;
  }

  SetOutputDeviceAddressFlag(op_compiler_info, op_run_info);

  const auto *device_context = op_compiler_info->device_context_;
  // Execute all kernels
  for (size_t i = 0; i < op_num; ++i) {
    const auto &single_op = single_ops[i];
    const CNodePtr &kernel = single_op->kernel_;
    MS_EXCEPTION_IF_NULL(kernel);

    // Fetch input kernel tensors.
    const auto &input_edges = single_op->inputs_;
    const auto &output_edges = single_op->outputs_;

    const auto &input_kernel_tensors = GetInputKernelTensors(input_edges);
    const auto &input_abstracts = GetInputInferAbstract(input_edges);
    const auto &output_kernel_tensors = GetOutputKernelTensors(output_edges, device_context);

    BaseShapePtr out_shape;
    if (is_need_infer) {
      out_shape = InferNodeRealShape(kernel, input_abstracts);
    } else {
      kernel->set_abstract(op_run_info->base_op_run_info.abstract);
      out_shape = op_run_info->base_op_run_info.abstract->GetShape();
    }
    // Update output kernel tensor.
    opt::dynamic_shape::UpdateKernelTensorShape(out_shape, output_kernel_tensors);

    // Resize
    ResizeKernelMod(kernel, input_kernel_tensors, output_kernel_tensors);

    // Malloc workspace memory
    std::vector<device::DeviceAddressPtr> workspace_device_address;
    auto workspace_kernel_tensors = GetWorkspaceKernelTensorsDynamic(device_context, kernel, &workspace_device_address);

    // Update output tensor shape
    UpdateOutputDeviceInfo(output_edges, kernel);

    // Malloc output tensor memory
    AllocateOutputMemory(output_edges, device_context);

    // Launch kernel
    MS_EXCEPTION_IF_NULL(device_context);
    MS_EXCEPTION_IF_NULL(device_context->GetKernelExecutor(true));
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    const size_t stream_id = op_run_info->base_op_run_info.stream_id;
    auto stream = device_context->device_res_manager_->GetStream(stream_id);
    if (!device_context->GetKernelExecutor(true)->LaunchKernel(kernel, input_kernel_tensors, workspace_kernel_tensors,
                                                               output_kernel_tensors, kernel_mod, stream)) {
      MS_LOG(EXCEPTION) << "Launch kernel failed, name:" << kernel->fullname_with_scope();
    }

    if (is_need_infer) {
      if (kernel_mod->IsNeedUpdateOutputShapeAndSize()) {
        kernel_mod->UpdateOutputShapeAndSize(input_kernel_tensors, output_kernel_tensors);
        UpdateOutputShape(output_edges);
      }
    }
    runtime::DeviceAddressUtils::ProcessCrossStreamAddress(op_run_info->base_op_run_info.op_name, device_context,
                                                           stream_id, input_kernel_tensors, output_kernel_tensors);
  }
}

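// Bind the input tensors' device addresses to the graph's input edges, converting strided
// (non-contiguous) addresses and creating new addresses for tensors that have none.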
void DynamicOpRunner::UpdateInputDeviceAddress(const OpCompilerInfoPtr &op_compiler_info,
                                               const std::vector<tensor::BaseTensorPtr> &input_tensors, bool is_sync) {
  MS_LOG(DEBUG) << "Start update input device address for " << op_compiler_info->graph_info_;
  const auto &simple_graph = op_compiler_info->simple_graph_;
  auto input_tensors_num = input_tensors.size();
  auto op_input_num = simple_graph->inputs_.size();
  if (input_tensors_num != op_input_num) {
    MS_LOG(EXCEPTION) << "Real input tensor's num " << input_tensors_num << " is not equal to op input num "
                      << op_input_num << "!";
  }
  const auto &device_context = op_compiler_info->device_context_;
  const auto &inputs = simple_graph->inputs_;
  for (size_t i = 0; i < input_tensors_num; ++i) {
    const auto &input_tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(input_tensor);
    const auto &input_edge = inputs[i];
    // input_edge->address_ is null here.
    UpdateInputTensorForHeterogeneous(device_context, input_tensor, input_edge->origin_address_);
    const auto &device_sync = input_tensor->device_address();
    const auto &device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);

    const auto &input_node = input_edge->node_with_index_.first;
    common::AnfAlgo::SetOutputInferTypeAndShape({input_tensor->data_type()}, {input_tensor->shape()}, input_node.get());
    if (device_address != nullptr) {
      if (device_address->GetTensorStorageInfo() != nullptr) {
        auto new_device_address =
          DeviceAddressUtils::ConvertContiguousDeviceAddress(device_context, device_address, is_sync);
        input_edge->address_ = new_device_address;
        input_tensor->set_device_address(new_device_address);
      } else {
        // Always use the tensor address as the kernel address.
        input_edge->address_ = device_address;
      }
    } else {
      UpdateAddressInfoByInputTensor(op_compiler_info, input_tensor, input_edge, input_node);
      if (input_edge->ignore_h2d_) {
        input_edge->address_->kernel_tensor()->SetValue(input_tensor);
        MS_LOG(DEBUG) << "Ignore host to device for " << op_compiler_info->graph_info_;
      } else {
        input_tensor->set_device_address(input_edge->address_);
      }
    }
  }
  MS_LOG(DEBUG) << "End update input device address for " << op_compiler_info->graph_info_;
}

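// Allocate device memory and sync host data for inputs that still lack a device pointer,
// honoring edges marked to skip the host-to-device copy.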
void DynamicOpRunner::CopyHostToDevice(const OpCompilerInfoPtr &op_compiler_info,
                                       const std::vector<tensor::BaseTensorPtr> &input_tensors) {
  const auto &input_edges = op_compiler_info->simple_graph_->inputs_;
  auto input_tensors_num = input_tensors.size();
  auto input_edge_num = input_edges.size();
  if (input_tensors_num != input_edge_num) {
    MS_LOG(EXCEPTION) << "Real input tensor's number " << input_tensors_num << " is not equal to input edges number "
                      << input_edge_num << "!";
  }

  const auto &device_context = op_compiler_info->device_context_;
  for (size_t i = 0; i < input_tensors_num; ++i) {
    const auto &input_tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(input_tensor);
    const auto &device_sync = input_tensor->device_address();
    const auto &device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);

    const auto &input_edge = input_edges[i];
    if (input_edge->ignore_h2d_) {
      continue;
    }

    const auto &input_node = input_edge->node_with_index_.first;
    MS_EXCEPTION_IF_NULL(input_node);
    common::AnfAlgo::SetOutputInferTypeAndShape({input_tensor->data_type()}, {input_tensor->shape()}, input_node.get());

    if (device_address == nullptr) {
      MS_LOG(EXCEPTION) << "Input DeviceAddress cannot be null before copy host to device, op name "
                        << op_compiler_info->graph_info_;
    }

    if (device_address->GetMutablePtr() != nullptr) {
      continue;
    }

    auto mem_type =
      input_tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kPyNativeInput;
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", mem_type, device_address->GetSize(),
                                                   device_address.get());
    if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Device(id:" << device_context->device_context_key().device_id_
                        << ") memory isn't enough and alloc failed, kernel name: " << input_node->DebugString()
                        << ", alloc size: " << device_address->GetSize() << "B.";
    }
    if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), device_address->GetSize(),
                                          device_address->type_id(), "DefaultFormat", input_tensor->data_ptr())) {
      MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
    }
    MS_LOG(DEBUG) << "Copy host tensor to device for op " << op_compiler_info->graph_info_ << " input " << i;
  }
}
}  // namespace mindspore::runtime