/**
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/pynative/op_runner.h"

#include <string>
#include <vector>
#include <memory>
#include <map>
#include <unordered_map>
#include <algorithm>
#include <array>
#include "ops/structure_op_name.h"
#include "utils/log_adapter.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/backend/optimizer/helper.h"
#include "include/backend/device_type.h"
#include "include/common/utils/convert_utils.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/device/device_address_utils.h"
#include "runtime/pynative/op_runtime_info.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/pynative/op_compiler.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "kernel/framework_utils.h"
#include "include/backend/mem_reuse/mem_tracker.h"
#ifndef ENABLE_SECURITY
#include "include/backend/debug/profiler/profiling.h"

using mindspore::profiler::ProfilerManager;
#endif
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
#include "pybind_api/gil_scoped_long_running.h"
#include "runtime/pynative/ir_converter.h"

using EdgePtr = mindspore::pynative::EdgePtr;

namespace mindspore::runtime {
namespace {
constexpr size_t kContextSize = 4;
std::unique_ptr<std::mutex> kDeviceContextMutex = std::make_unique<std::mutex>();
std::array<DeviceContext *, kContextSize> kDeviceContexts = {nullptr, nullptr, nullptr, nullptr};

// 1. Device type is different in heterogeneous scenes.
// 2. The device address format is different.
void UpdateInputTensorFromDevice(const std::vector<AnfNodePtr> &input_nodes,
                                 const std::vector<tensor::BaseTensorPtr> &input_tensors,
                                 const device::DeviceContext *device_context) {
  MS_LOG(DEBUG) << "Start";
  auto input_size = input_nodes.size();
  for (size_t i = 0; i < input_size; ++i) {
    auto &tensor = input_tensors[i];
    auto &input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    auto node_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
    // node_address can't be null
    MS_EXCEPTION_IF_NULL(node_address);
    MS_EXCEPTION_IF_NULL(device_context);
    if (tensor_address != nullptr) {
      if (tensor_address->GetDeviceType() != device_context->GetDeviceType() ||
          tensor_address->format() != node_address->format()) {
        // Need wait for OpExecutor task finish
        tensor->data_sync();
        // If tensor address is null, we will set Parameter address to the Tensor.
        tensor->set_device_address(nullptr);
      }
    }
  }
  MS_LOG(DEBUG) << "End";
}

void UpdateParameterShapeFromInputTensor(const AnfNodePtr &input_node, const tensor::BaseTensorPtr &input_tensor) {
  MS_EXCEPTION_IF_NULL(input_node);
  if (input_tensor == nullptr || !input_node->isa<Parameter>()) {
    return;
  }

  auto input_param = input_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(input_param);
  if (!input_param->has_dynamic_shape()) {
    return;
  }

  auto shape = input_tensor->shape();
  MS_LOG(DEBUG) << "Update input node shape to:" << shape;
  common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape},
                                              input_node.get());
}

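// Bind the graph input node and the input tensor to a common DeviceAddress: if the tensor has no address yet,
// give it the node's address and mark it for host-to-device sync; otherwise let the node reuse the tensor's
// address (converting a non-contiguous address to a contiguous one first).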
void SetDeviceAddress(const AnfNodePtr &input_node, const tensor::BaseTensorPtr &input_tensor,
                      const device::DeviceContext *device_context, bool is_sync) {
  MS_EXCEPTION_IF_NULL(input_tensor);
  auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
  auto node_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);

  UpdateParameterShapeFromInputTensor(input_node, input_tensor);

  MS_EXCEPTION_IF_NULL(node_address);
  if (tensor_address == nullptr) {
    input_tensor->set_device_address(node_address);
    input_tensor->set_sync_status(kNeedSyncHostToDeviceImmediately);
    input_tensor->set_need_pipeline_sync(true);
    node_address->set_from_persistent_mem(input_tensor->is_parameter());
    node_address->SetNodeIndex(input_node, 0);
  }

  // The DeviceType and format of the DeviceAddress are always the same after UpdateInputTensorFromDevice.
  if (tensor_address != nullptr && tensor_address != node_address) {
    auto address = tensor_address;
    if (tensor_address->GetTensorStorageInfo() != nullptr) {
      address = DeviceAddressUtils::ConvertContiguousDeviceAddress(device_context, tensor_address, is_sync);
      input_tensor->set_device_address(address);
    }
    AnfAlgo::SetOutputAddr(address, 0, input_node.get());
  }
}

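// Attach device addresses of the graph input nodes to the actual input tensors (including the key/value/status
// tensors of a MapTensor), so the kernels below read their inputs from the tensors' device memory.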
void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes,
                                  const std::vector<tensor::BaseTensorPtr> &input_tensors,
                                  const device::DeviceContext *device_context, bool is_sync) {
  MS_LOG(DEBUG) << "Start";
  auto input_size = input_nodes.size();
  auto tensor_size = input_tensors.size();
  if (input_size != tensor_size) {
    MS_LOG(EXCEPTION) << "input node size:" << input_size << " not equal to tensors size:" << tensor_size;
  }
  for (size_t i = 0; i < input_size; ++i) {
    auto &input_node = input_nodes[i];
    auto &input_tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(input_tensor);
    if (input_tensor->isa<tensor::MapTensor>()) {
      auto map_tensor = input_tensor->cast<tensor::MapTensorPtr>();
      MS_EXCEPTION_IF_NULL(map_tensor);
      SetDeviceAddress(input_node, map_tensor, device_context, is_sync);
      SetDeviceAddress(input_node, map_tensor->key_tensor(), device_context, is_sync);
      SetDeviceAddress(input_node, map_tensor->value_tensor(), device_context, is_sync);
      SetDeviceAddress(input_node, map_tensor->status_tensor(), device_context, is_sync);
    } else {
      SetDeviceAddress(input_node, input_tensor, device_context, is_sync);
    }
  }
  MS_LOG(DEBUG) << "End";
}

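// Allocate device memory for the tensor's DeviceAddress (if not yet allocated) and copy the host data into it.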
void CopyTensorDataToDevice(const tensor::BaseTensorPtr &tensor, const AnfNodePtr &node,
                            const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_EXCEPTION_IF_CHECK_FAIL(device_address != nullptr, "Tensor device address is nullptr, id is " + tensor->id());
  // Skip copying data to the device if the device_address has the ignore flag.
  if (TEST_FLAG(device_address->flag(), device::kDeviceAddressFlagIgnoreDevicePtr)) {
    MS_LOG(DEBUG) << "Node " << node->DebugString() << " with address " << device_address
                  << " has flag ignore device address, so skip copy tensor to device";
    return;
  }

  auto mem_type = tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kPyNativeInput;
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", mem_type, device_address->GetSize(),
                                                 device_address.get());
  if ((device_address->GetPtr() == nullptr) &&
      (!device_context->device_res_manager_->AllocateMemory(device_address.get()))) {
    MS_LOG(EXCEPTION) << "Allocate memory failed, alloc size " << device_address->GetSize() << "B";
  }
  // Copy data from host tensor to device.
  auto tensor_size = LongToSize(tensor->data().nbytes());
  auto tensor_type = tensor->data_type();
  MS_LOG(DEBUG) << "Copy to device, node:" << common::AnfAlgo::GetNodeDebugString(node);
  if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), tensor_size, tensor_type,
                                        "DefaultFormat", tensor->data_ptr())) {
    MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
  }
}


void CopyValueNodeDataToDevice(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(DEBUG) << "Start";
  const auto &value_nodes = graph->graph_value_nodes();
  for (const auto &value_node : value_nodes) {
    MS_EXCEPTION_IF_NULL(value_node);
    const auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    if (!node_value->isa<tensor::BaseTensor>() && !node_value->isa<ValueTuple>() && !node_value->isa<Scalar>() &&
        !node_value->isa<StringImm>()) {
      MS_LOG(INFO) << "Unknown value node type:" << value_node->DebugString();
      continue;
    }

    const auto &node_address = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
    MS_EXCEPTION_IF_NULL(node_address);
    node_address->SetNodeIndex(value_node, 0);
    if (node_address->GetPtr() != nullptr) {
      continue;
    }
    auto shape = trans::GetRuntimePaddingShape(value_node, 0);
    runtime::DeviceAddressUtils::CopyNoneTensorDataToDevice(device_context, node_address, shape);
  }
  MS_LOG(DEBUG) << "End";
}

void UpdateAddressSizeForDynamicShapeTensor(const tensor::BaseTensorPtr &input_tensor) {
  MS_EXCEPTION_IF_NULL(input_tensor);
  if (input_tensor->base_shape_ptr() != nullptr) {
    auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
    MS_EXCEPTION_IF_NULL(device_address);
    auto tensor_size = LongToSize(input_tensor->data().nbytes());
    if (tensor_size != device_address->GetSize()) {
      device_address->SetSize(tensor_size);
    }
  }
}

void CopyMapTensorDataToDevice(const tensor::MapTensorPtr &map_tensor, const AnfNodePtr &input_node,
                               const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(map_tensor);
  auto key_tensor = map_tensor->key_tensor();
  MS_EXCEPTION_IF_NULL(key_tensor);
  UpdateAddressSizeForDynamicShapeTensor(key_tensor);
  CopyTensorDataToDevice(key_tensor, input_node, device_context);
  key_tensor->set_sync_status(kNoNeedSync);
  auto value_tensor = map_tensor->value_tensor();
  MS_EXCEPTION_IF_NULL(value_tensor);
  UpdateAddressSizeForDynamicShapeTensor(value_tensor);
  CopyTensorDataToDevice(value_tensor, input_node, device_context);
  value_tensor->set_sync_status(kNoNeedSync);
  auto status_tensor = map_tensor->status_tensor();
  MS_EXCEPTION_IF_NULL(status_tensor);
  UpdateAddressSizeForDynamicShapeTensor(status_tensor);
  CopyTensorDataToDevice(status_tensor, input_node, device_context);
  status_tensor->set_sync_status(kNoNeedSync);
}

void CopyParameterDataToDevice(const std::vector<AnfNodePtr> &input_nodes,
                               const std::vector<tensor::BaseTensorPtr> &input_tensors,
                               const device::DeviceContext *device_context) {
  MS_LOG(DEBUG) << "Start";
  auto input_size = input_nodes.size();
  if (input_size > input_tensors.size()) {
    MS_LOG(EXCEPTION) << "input_size is bigger than input_tensors size, input_size:" << input_size
                      << ", input_tensors size:" << input_tensors.size();
  }
  for (size_t i = 0; i < input_size; ++i) {
    MS_EXCEPTION_IF_NULL(input_tensors[i]);
    if (input_tensors[i]->NeedSyncHostToDeviceImmediately()) {
      // First op in dynamic shape scenario(feed mode)
      if (input_tensors[i]->isa<tensor::MapTensor>()) {
        auto map_tensor = input_tensors[i]->cast<tensor::MapTensorPtr>();
        CopyMapTensorDataToDevice(map_tensor, input_nodes[i], device_context);
      } else {
        UpdateAddressSizeForDynamicShapeTensor(input_tensors[i]);
        CopyTensorDataToDevice(input_tensors[i], input_nodes[i], device_context);
        input_tensors[i]->set_sync_status(kNoNeedSync);
      }
    }
  }
  MS_LOG(DEBUG) << "End";
}

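// Allocate device memory for every kernel input address that is still unallocated; None inputs and addresses
// flagged to ignore the device pointer are skipped.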
bool MallocForKernelInput(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
                          const device::DeviceContext *device_context, const CNodePtr &node) {
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(runtime_info);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  auto input_size = runtime_info->GetInputSize();
  for (size_t i = 0; i < input_size; ++i) {
    if (common::AnfAlgo::IsNoneInput(node, i)) {
      MS_EXCEPTION_IF_NULL(node);
      MS_LOG(DEBUG) << "Input [" << i << "] of " << node->fullname_with_scope() << " is None, no need to allocate.";
      continue;
    }
    auto input_address = runtime_info->GetInputDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    MS_EXCEPTION_IF_NULL(input_address);
    if (TEST_FLAG(input_address->flag(), device::kDeviceAddressFlagIgnoreDevicePtr)) {
      MS_LOG(DEBUG) << "Node " << node->DebugString() << " input[" << i << "] with address " << input_address
                    << " has flag ignore device address, so skip malloc device address";
      continue;
    }
    if (input_address->GetPtr() == nullptr) {
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kPyNativeOutput,
                                                     input_address->GetSize(), input_address.get());
      if (!device_context->device_res_manager_->AllocateMemory(input_address.get())) {
        MS_LOG(EXCEPTION) << "Allocate memory failed, alloc size " << input_address->GetSize() << "B";
      }
    }
  }
  return true;
}

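// Allocate device memory for the kernel outputs; output sizes are re-checked against the KernelMod because some
// kernels only know their real output size after Resize.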
bool MallocForKernelOutput(const std::shared_ptr<OpRuntimeInfo> &runtime_info, const AnfNodePtr &node,
                           const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(runtime_info);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);

  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_size = runtime_info->GetOutputSize();
  auto kernel_out_size_list = kernel_mod->GetOutputSizeList();
  if (kernel_out_size_list.size() != output_size) {
    MS_LOG(ERROR) << "Node " << node->fullname_with_scope() << " output num is:" << output_size
                  << " but kernel_mod output num:" << kernel_out_size_list.size();
    return false;
  }
  for (size_t i = 0; i < output_size; ++i) {
    auto device_address = runtime_info->GetOutputDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    // For example, we need to call cudnnGetRNNTrainingReserveSize to get real output size in LstmGpuKernelMod!
    if (kernel_out_size_list[i] != device_address->GetSize() &&
        AnfAlgo::GetOutputFormat(node, i) == device_address->format()) {
      // If the format of the DeviceAddress is different, then the size is originally different.
      // Such as NCHW(1,1,1,3) and NC1HWC0(1,1,1,1,16). So we don't need to update the size.
      if (device_address->GetPtr() != nullptr) {
        MS_LOG(ERROR) << "kernel mod output " << i << " size:" << kernel_out_size_list[i]
                      << " not equal to device_address size:" << device_address->GetSize()
                      << ", but the device address already has a ptr";
        return false;
      }
      device_address->SetSize(kernel_out_size_list[i]);
    }
    if (device_address->GetPtr() == nullptr) {
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kPyNativeOutput,
                                                     device_address->GetSize(), device_address.get());
      if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
        MS_LOG(EXCEPTION) << "Allocate output memory failed, alloc node:" << node->fullname_with_scope()
                          << " alloc size:" << device_address->GetSize() << "B";
      }
    }
  }
  return true;
}

std::vector<kernel::KernelTensor *> GetInputKernelTensors(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
                                                          const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(runtime_info);
  auto input_size = runtime_info->GetInputSize();
  std::vector<kernel::KernelTensor *> inputs;
  for (size_t i = 0; i < input_size; ++i) {
    auto device_address = runtime_info->GetInputDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    (void)inputs.emplace_back(device_address->kernel_tensor().get());
    MS_EXCEPTION_IF_NULL(inputs.back());
    MS_LOG(DEBUG) << "input[" << i << "]:" << inputs.back()->device_ptr() << " size:" << inputs.back()->size();
  }
  return inputs;
}

std::vector<abstract::AbstractBasePtr> GetInputKernelTensorsForInfer(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
                                                                     const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(runtime_info);
  auto input_size = runtime_info->GetInputSize();
  std::vector<abstract::AbstractBasePtr> inputs;
  for (size_t i = 0; i < input_size; ++i) {
    auto device_address = runtime_info->GetInputDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    (void)inputs.emplace_back(device_address->kernel_tensor());
    MS_EXCEPTION_IF_NULL(inputs.back());
  }
  return inputs;
}

std::vector<kernel::KernelTensor *> GetWorkspaceKernelTensors(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
                                                              const device::DeviceContext *device_context,
                                                              size_t workspace_size, size_t workspace_sizes) {
  std::vector<kernel::KernelTensor *> workspaces;
  for (size_t i = 0; i < workspace_size && i < workspace_sizes; ++i) {
    auto device_address = runtime_info->GetWorkspaceDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kWorkSpace,
                                                   device_address->GetSize(), device_address.get());
    if (device_address->GetPtr() == nullptr &&
        !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate workspace memory failed, alloc size:" << device_address->GetSize() << "B";
    }
    (void)workspaces.emplace_back(device_address->kernel_tensor().get());
    MS_EXCEPTION_IF_NULL(workspaces.back());
    MS_LOG(DEBUG) << "workspace[" << i << "]:" << workspaces.back()->device_ptr()
                  << " size:" << workspaces.back()->size();
  }
  return workspaces;
}

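// Collect the workspace kernel tensors of a kernel. For dynamic shape/value kernels the workspace list may have
// grown after Resize, so extra workspace device addresses are created and allocated on demand.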
std::vector<kernel::KernelTensor *> GetWorkspaceKernelTensors(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
                                                              const device::DeviceContext *device_context,
                                                              const CNodePtr &kernel, bool is_dynamic_shape,
                                                              bool is_dynamic_value) {
  MS_EXCEPTION_IF_NULL(runtime_info);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  auto workspace_size = runtime_info->GetWorkspaceSize();
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();

  std::vector<device::DeviceAddressPtr> add_workspaces;
  if (is_dynamic_shape || is_dynamic_value) {
    // Resize of workspaces, because of the dynamic size of workspace.
    if (workspace_size < workspace_sizes.size()) {
      for (size_t i = workspace_size; i < workspace_sizes.size(); ++i) {
        auto kernel_tensor = std::make_shared<KernelTensor>(
          nullptr, workspace_sizes[i], Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
          device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
        auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
        MS_EXCEPTION_IF_NULL(device_address);
        MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
                      << " addr:" << device_address;
        AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());  // set to kernel_info
        (void)add_workspaces.emplace_back(device_address);
      }
    }
  }

  // Set workspace address new size
  for (size_t i = 0; i < workspace_size && i < workspace_sizes.size(); ++i) {
    auto device_address = runtime_info->GetWorkspaceDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    device_address->SetSize(workspace_sizes[i]);
  }

  std::vector<kernel::KernelTensor *> workspaces =
    GetWorkspaceKernelTensors(runtime_info, device_context, workspace_size, workspace_sizes.size());
  for (size_t i = workspace_size; i < workspace_sizes.size(); ++i) {
    // add_workspaces is indexed from 0 and its first element corresponds to workspace index `workspace_size`.
    auto device_address = add_workspaces[i - workspace_size];
    MS_EXCEPTION_IF_NULL(device_address);
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kWorkSpace,
                                                   device_address->GetSize(), device_address.get());
    if (device_address->GetPtr() == nullptr &&
        !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate workspace memory failed, alloc size:" << device_address->GetSize() << "B";
    }
    (void)workspaces.emplace_back(device_address->kernel_tensor().get());
    MS_LOG(DEBUG) << "workspace[" << i << "]:" << workspaces.back()->device_ptr()
                  << " size:" << workspaces.back()->size();
  }
  return workspaces;
}

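// Create and allocate workspace device addresses from scratch for a dynamically executed kernel; the created
// addresses are also returned through workspace_device_address so they stay alive during the launch.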
std::vector<kernel::KernelTensor *> GetWorkspaceKernelTensorsDynamic(
  const device::DeviceContext *device_context, const CNodePtr &kernel,
  std::vector<device::DeviceAddressPtr> *workspace_device_address) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();

  std::vector<kernel::KernelTensor *> workspaces;
  workspaces.reserve(workspace_sizes.size());
  for (size_t i = 0; i < workspace_sizes.size(); ++i) {
    auto kernel_tensor = std::make_shared<KernelTensor>(
      nullptr, workspace_sizes[i], Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
    auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    MS_EXCEPTION_IF_NULL(device_address);
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kWorkSpace,
                                                   device_address->GetSize(), device_address.get());
    if (device_address->GetPtr() == nullptr &&
        !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate dynamic workspace memory failed, alloc size:" << device_address->GetSize() << "B";
    }
    MS_EXCEPTION_IF_NULL(workspace_device_address);
    (void)workspace_device_address->emplace_back(device_address);
    (void)workspaces.emplace_back(device_address->kernel_tensor().get());
    MS_LOG(DEBUG) << "workspace[" << i << "]:" << workspaces.back()->device_ptr()
                  << " size:" << workspaces.back()->size();
  }
  return workspaces;
}

std::vector<kernel::KernelTensor *> GetOutputKernelTensors(const std::shared_ptr<OpRuntimeInfo> &runtime_info) {
  MS_EXCEPTION_IF_NULL(runtime_info);
  auto output_size = runtime_info->GetOutputSize();
  std::vector<kernel::KernelTensor *> outputs;
  for (size_t i = 0; i < output_size; ++i) {
    auto device_address = runtime_info->GetOutputDeviceAddress(i);
    MS_EXCEPTION_IF_NULL(device_address);
    (void)outputs.emplace_back(device_address->kernel_tensor().get());
    MS_LOG(DEBUG) << "output[" << i << "]:" << outputs.back()->device_ptr() << " size:" << outputs.back()->size();
  }
  return outputs;
}

// Host to Device or Device to Host
void CopyDataToDevice(const KernelGraphPtr &graph, const std::vector<tensor::BaseTensorPtr> &input_tensors,
                      const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  CopyValueNodeDataToDevice(graph, device_context);
  CopyParameterDataToDevice(graph->input_nodes(), input_tensors, device_context);
}

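// Infer the real output shape of the kernel from the given input abstracts (used on the dynamic-shape path).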
BaseShapePtr InferNodeRealShape(const CNodePtr &kernel, const std::vector<abstract::AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(kernel);
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelInfer,
                                     kernel->fullname_with_scope(), false);
  auto *kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  return opt::dynamic_shape::InferShape(kernel_mod->primitive(), input_args);
}

void ResizeKernelMod(const CNodePtr &kernel, const std::vector<kernel::KernelTensor *> &inputs,
                     const std::vector<kernel::KernelTensor *> &outputs) {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelResize,
                                     kernel->fullname_with_scope(), false);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  kernel_mod->set_use_kernel_tensor(true);

  int ret = kernel_mod->Resize(inputs, outputs);
  if (ret != kernel::KRET_OK) {
    MS_LOG(EXCEPTION) << "Resize failed for kernel: " << kernel->fullname_with_scope();
  }
}

void SetOutputDeviceAddressFlag(const pynative::OpCompilerInfoPtr &op_compiler_info,
                                const session::BackendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &simple_graph = op_compiler_info->simple_graph_;
  size_t output_size = simple_graph->outputs_.size();
  // Reset grad output flag.
  const auto &outputs = simple_graph->outputs_;
  for (const auto &output : outputs) {
    output->is_grad_ = false;
  }

  if (op_run_info->is_gradient_out) {
    const auto &output_indexes = op_run_info->base_op_run_info.output_indexes;
    for (auto index : output_indexes) {
      if (index >= output_size) {
        MS_LOG(EXCEPTION) << "Gradient output index " << index << " >= graph output size " << output_size;
      }
      const auto &output = outputs[index];
      MS_EXCEPTION_IF_NULL(output);
      output->is_grad_ = true;
      MS_LOG(DEBUG) << "Set grad flag for op " << op_run_info->base_op_run_info.op_name << " index " << index;
    }
  }
}

void MallocForConstValue(const pynative::OpCompilerInfoPtr &op_compiler_info) {
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  const auto &device_context = op_compiler_info->device_context_;
  const auto &graph = op_compiler_info->graph_;
  CopyValueNodeDataToDevice(graph, device_context);
}

void UpdateOutputShape(const std::vector<EdgePtr> &output_edges) {
  for (const auto &edge : output_edges) {
    MS_EXCEPTION_IF_NULL(edge);
    const auto &device_address = edge->address_;
    MS_EXCEPTION_IF_NULL(device_address);
    const auto &kernel_tensor = device_address->kernel_tensor();
    MS_EXCEPTION_IF_NULL(kernel_tensor);
    device_address->set_host_shape(kernel_tensor->host_info_exist() ? kernel_tensor->GetShapeVector()
                                                                    : kernel_tensor->host_shape());
  }
}

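// Launch every kernel of the compiled single-op graph in execution order: allocate input/workspace/output memory,
// re-infer and resize dynamic kernels, then call the kernel executor on the op's stream.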
void LaunchKernels(const KernelGraphPtr &graph, const device::DeviceContext *device_context,
                   const session::BackendOpRunInfoPtr &op_run_info,
                   const std::vector<tensor::BaseTensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_LOG(DEBUG) << "Start";

  // Get device address from OpRuntimeInfo
  const auto &execution_order = graph->execution_order();
  for (auto const &node : execution_order) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start launch kernel " << node->fullname_with_scope() << " kernel type "
                  << AnfAlgo::GetKernelType(node);
    auto is_dynamic_shape = common::AnfAlgo::IsDynamicShape(node);
    bool is_dynamic_value = common::AnfAlgo::IsDynamicValue(node);
    auto runtime_info = node->user_data<runtime::OpRuntimeInfo>();
    MS_EXCEPTION_IF_NULL(runtime_info);

    if (!MallocForKernelInput(runtime_info, device_context, node)) {
      MS_LOG(EXCEPTION) << "Malloc for kernel input failed, memory isn't enough, node:" << node->fullname_with_scope();
    }

    auto inputs = GetInputKernelTensors(runtime_info, node);
    auto outputs = GetOutputKernelTensors(runtime_info);
    if (is_dynamic_shape) {
      auto input_kernel_tensors_for_infer = GetInputKernelTensorsForInfer(runtime_info, node);
      auto out_shape = InferNodeRealShape(node, input_kernel_tensors_for_infer);
      opt::dynamic_shape::UpdateKernelTensorShape(out_shape, outputs);
      ResizeKernelMod(node, inputs, outputs);
    } else if (is_dynamic_value) {
      auto kernel_mod = runtime_info->GetKernelMod();
      MS_EXCEPTION_IF_NULL(kernel_mod);
      if (kernel_mod->Resize(inputs, outputs) != static_cast<int>(kernel::KRET_OK)) {
        MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " resize failed";
      }
    }
    auto workspaces = GetWorkspaceKernelTensors(runtime_info, device_context, node, is_dynamic_shape, is_dynamic_value);

    if (!MallocForKernelOutput(runtime_info, node, device_context)) {
      MS_LOG(EXCEPTION) << "Malloc for kernel output failed, memory isn't enough, node:" << node->fullname_with_scope();
    }

    MS_EXCEPTION_IF_NULL(device_context);
    MS_EXCEPTION_IF_NULL(device_context->GetKernelExecutor(false));
    auto kernel_mod = AnfAlgo::GetKernelMod(node);
    const size_t stream_id = op_run_info->base_op_run_info.stream_id;
    auto stream = device_context->device_res_manager_->GetStream(stream_id);
    if (!device_context->GetKernelExecutor(false)->LaunchKernel(node, inputs, workspaces, outputs, kernel_mod,
                                                                stream)) {
      MS_LOG(EXCEPTION) << "Launch kernel failed, name:" << node->fullname_with_scope();
    }
    runtime::DeviceAddressUtils::ProcessCrossStreamAddress(op_run_info->base_op_run_info.op_name, device_context,
                                                           stream_id, inputs, outputs);
  }
  MS_LOG(DEBUG) << "End";
}


void AllocateOutputMemory(const std::vector<EdgePtr> &output_edges, const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_context);
  for (const auto &edge : output_edges) {
    MS_EXCEPTION_IF_NULL(edge);
    const auto &device_address = edge->address_;
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetPtr() == nullptr) {
      if (edge->is_grad_) {
        device_address->set_from_persistent_mem(true);
      }
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kPyNativeOutput,
                                                     device_address->GetSize(), device_address.get());
      MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
      if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
        MS_LOG(EXCEPTION) << "Allocate device memory failed, alloc size:" << device_address->GetSize() << "B";
      }
    }
  }
}

void UpdateOutputDeviceInfo(const std::vector<EdgePtr> &edges, const CNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_size_list = kernel_mod->GetOutputSizeList();
  if (edges.size() != output_size_list.size()) {
    MS_LOG(EXCEPTION) << "Output device address's size " << edges.size() << " is not equal output_size_list's size "
                      << output_size_list.size();
  }

  auto output_num = edges.size();
  for (size_t i = 0; i < output_num; ++i) {
    const auto &edge = edges[i];
    MS_EXCEPTION_IF_NULL(edge);
    const auto &device_address = edge->address_;
    MS_EXCEPTION_IF_NULL(device_address);
    const auto &kernel_tensor = device_address->kernel_tensor();
    MS_EXCEPTION_IF_NULL(kernel_tensor);
    device_address->set_host_shape(kernel_tensor->GetShapeVector());
    device_address->SetSize(output_size_list[i]);
  }
}

void UpdateInputTensorForHeterogeneous(const DeviceContext *device_context, const tensor::BaseTensorPtr &input_tensor,
                                       const device::DeviceAddressPtr &cached_device_address) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(cached_device_address);
  MS_EXCEPTION_IF_NULL(input_tensor);
  auto device_sync = input_tensor->device_address();
  if (device_sync == nullptr) {
    return;
  }
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->GetDeviceType() != device_context->GetDeviceType() ||
      device_address->format() != cached_device_address->format()) {
    // Need wait for OpExecutor task finish
    input_tensor->data_sync();
    // If tensor address is null, we will set Parameter address to the Tensor.
    input_tensor->set_device_address(nullptr);
  }
}

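// Build a fresh DeviceAddress for an input edge from the tensor's shape and the cached (origin) address's
// format/dtype; used when the input tensor carries no device address of its own.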
void UpdateAddressInfoByInputTensor(const OpCompilerInfoPtr &op_compiler_info, const tensor::BaseTensorPtr &tensor,
                                    const EdgePtr &edge, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(node);
  auto &device_context = op_compiler_info->device_context_;
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);

  auto origin_address = edge->origin_address_;

  const auto &format = origin_address->format();
  const auto dtype = origin_address->type_id();
  const auto &shape = tensor->shape();
  size_t tensor_size = DeviceAddressUtils::GetTensorDeviceSize(device_context, node, shape, format, dtype, 0);

  const auto &kernel_tensor = origin_address->kernel_tensor();
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  auto new_kernel_tensor = kernel_tensor->CloneKernelTensor();
  MS_EXCEPTION_IF_NULL(new_kernel_tensor);

  new_kernel_tensor->SetShapeVector(shape);
  new_kernel_tensor->set_device_ptr(nullptr);
  auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(new_kernel_tensor);
  MS_EXCEPTION_IF_NULL(new_device_address);
  new_device_address->set_host_shape(shape);
  new_device_address->SetSize(tensor_size);
  new_device_address->set_from_persistent_mem(tensor->is_parameter());
  edge->address_ = new_device_address;
}

std::vector<kernel::KernelTensor *> GetInputKernelTensors(const std::vector<EdgePtr> &edges) {
  std::vector<kernel::KernelTensor *> input_kernel_tensors;
  input_kernel_tensors.reserve(edges.size());
  (void)std::transform(edges.begin(), edges.end(), std::back_inserter(input_kernel_tensors), [](const EdgePtr &edge) {
    MS_EXCEPTION_IF_NULL(edge->address_);
    return edge->address_->kernel_tensor().get();
  });
  return input_kernel_tensors;
}

std::vector<abstract::AbstractBasePtr> GetInputInferAbstract(const std::vector<EdgePtr> &edges) {
  std::vector<abstract::AbstractBasePtr> input_abstracts;
  input_abstracts.reserve(edges.size());
  (void)std::transform(edges.begin(), edges.end(), std::back_inserter(input_abstracts), [](const EdgePtr &edge) {
    MS_EXCEPTION_IF_NULL(edge->address_);
    return edge->address_->kernel_tensor();
  });
  return input_abstracts;
}

std::vector<kernel::KernelTensor *> GetOutputKernelTensors(const std::vector<EdgePtr> &edges,
                                                           const DeviceContext *device_context) {
  std::vector<kernel::KernelTensor *> output_kernel_tensors;
  output_kernel_tensors.reserve(edges.size());
  for (const auto &edge : edges) {
    // For example, output is dynamic or the output is between two ops.
    if (edge->address_ == nullptr) {
      edge->address_ = runtime::DeviceAddressUtils::CloneEmptyDeviceAddress(edge->origin_address_, device_context);
    }
    const auto &output_address = edge->address_;
    MS_EXCEPTION_IF_NULL(output_address);
    output_kernel_tensors.push_back(output_address->kernel_tensor().get());
  }
  return output_kernel_tensors;
}
}  // namespace

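// Filter the op's expanded input values by their input type mask: create kernel tensors for all values and
// return only the non-constant entries as tensors.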
std::vector<tensor::BaseTensorPtr> OpRunner::GetTensorWithoutValueMask(
  const session::BackendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  std::vector<tensor::BaseTensorPtr> tensors_without_value_node;
  const auto &input_values = op_run_info->base_op_run_info.expanded_input_values;
  const auto &input_masks = op_run_info->base_op_run_info.input_types;
  if (input_values.size() != input_masks.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_values.size() << " should be equal to tensors mask size "
                      << input_masks.size();
  }
  for (size_t index = 0; index < input_masks.size(); ++index) {
    runtime::DeviceAddressUtils::CreateKernelTensor(input_values[index]);
    if (input_masks.at(index) != InputType::kConstant) {
      if (!input_values[index]->isa<tensor::BaseTensor>()) {
        MS_LOG(EXCEPTION) << "The " << index << "'s input should be a Tensor, but got "
                          << input_values[index]->ToString();
      }
      (void)tensors_without_value_node.emplace_back(input_values.at(index)->cast<tensor::BaseTensorPtr>());
    }
  }
  return tensors_without_value_node;
}


// Determine the address of the graph and do not change the address in subsequent executions
void OpRunner::UpdateDeviceAddress(const KernelGraphPtr &graph,
                                   const std::vector<tensor::BaseTensorPtr> &tensors_without_value_mask,
                                   const device::DeviceContext *device_context, bool is_sync) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(DEBUG) << "Start";
  const auto &input_nodes = graph->input_nodes();
  UpdateInputTensorFromDevice(input_nodes, tensors_without_value_mask, device_context);
  UpdateInputNodeDeviceAddress(input_nodes, tensors_without_value_mask, device_context, is_sync);
  pynative::OpCompiler::UpdateRefNodeOutputDeviceAddress(graph);
  MS_LOG(DEBUG) << "End";
}

void OpRunner::RunSingleOpGraph(const session::BackendOpRunInfoPtr &op_run_info,
                                const OpCompilerInfoPtr &op_compiler_info,
                                const std::vector<tensor::BaseTensorPtr> &input_tensors) {
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "PyNative", op_run_info->base_op_run_info.op_name,
                                                 op_compiler_info->graph_->ToString());
  CopyDataToDevice(op_compiler_info->graph_, input_tensors, op_compiler_info->device_context_);
  LaunchKernels(op_compiler_info->graph_, op_compiler_info->device_context_, op_run_info, input_tensors);
}

void OpRunner::LaunchKernelTask(const runtime::KernelTaskType &task_type, DeviceContext *device_context,
                                const device::DeviceAddressPtrList &input_addr_list,
                                const device::DeviceAddressPtrList &output_addr_list, size_t stream_id) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_LOG(DEBUG) << "Start, task_type:" << task_type;
  if (!device_context->GetKernelExecutor(false)->ExecuteKernelTask(task_type, input_addr_list, output_addr_list,
                                                                   stream_id)) {
    MS_LOG(EXCEPTION) << "ExecuteKernelTask failed, task_type:" << task_type;
  }
  MS_LOG(DEBUG) << "End";
}

DeviceContext *OpRunner::GetDeviceContext(const std::string &device_type) {
  auto type_iter = device::device_name_to_type_map.find(device_type);
  if (type_iter == device::device_name_to_type_map.end()) {
    MS_LOG(EXCEPTION) << "Invalid device_type " << device_type;
  }

  auto index = static_cast<size_t>(type_iter->second);
  auto cached_device_context = kDeviceContexts[index];

  if (cached_device_context != nullptr) {
    return cached_device_context;
  }

  GilReleaseWithCheck release_gil;
  std::unique_lock<std::mutex> lock(*kDeviceContextMutex);

  auto device_id = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_type, device_id});
  MS_EXCEPTION_IF_NULL(device_context);
  device_context->Initialize();

  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  (void)device_context->device_res_manager_->BindDeviceToCurrentThread(false);
  kDeviceContexts[index] = device_context;
  MS_LOG(DEBUG) << "Get device context of " << device_type << " id " << device_id;
  return device_context;
}

void OpRunner::ChildAfterFork() {
  kDeviceContexts.fill(nullptr);
  kDeviceContextMutex = std::make_unique<std::mutex>();
}

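// Run a compiled single-op graph kernel-by-kernel in dynamic mode: copy host inputs to device, infer shapes when
// needed, resize, allocate workspace/output memory, then launch each kernel on the op's stream.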
void DynamicOpRunner::RunSingleOpGraph(const session::BackendOpRunInfoPtr &op_run_info,
                                       const OpCompilerInfoPtr &op_compiler_info,
                                       const std::vector<tensor::BaseTensorPtr> &input_tensors) {
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "PyNative", op_run_info->base_op_run_info.op_name,
                                                 op_compiler_info->graph_->ToString());
  DynamicOpRunner::CopyHostToDevice(op_compiler_info, input_tensors);
  MallocForConstValue(op_compiler_info);

  const auto &simple_graph = op_compiler_info->simple_graph_;
  const auto &single_ops = simple_graph->single_ops_;
  bool is_need_infer = false;
  auto op_num = single_ops.size();
  MS_EXCEPTION_IF_NULL(op_run_info->base_op_run_info.abstract);
  if (op_num > 1 || op_run_info->base_op_run_info.abstract->BuildShape()->IsDynamic()) {
    is_need_infer = true;
  }

  SetOutputDeviceAddressFlag(op_compiler_info, op_run_info);

  const auto *device_context = op_compiler_info->device_context_;
  // Execute all kernels
  for (size_t i = 0; i < op_num; ++i) {
    const auto &single_op = single_ops[i];
    const CNodePtr &kernel = single_op->kernel_;
    MS_EXCEPTION_IF_NULL(kernel);

    // Fetch input kernel tensor.
    const auto &input_edges = single_op->inputs_;
    const auto &output_edges = single_op->outputs_;

    const auto &input_kernel_tensors = GetInputKernelTensors(input_edges);
    const auto &input_abstracts = GetInputInferAbstract(input_edges);
    const auto &output_kernel_tensors = GetOutputKernelTensors(output_edges, device_context);

    BaseShapePtr out_shape;
    if (is_need_infer) {
      out_shape = InferNodeRealShape(kernel, input_abstracts);
    } else {
      kernel->set_abstract(op_run_info->base_op_run_info.abstract);
      out_shape = op_run_info->base_op_run_info.abstract->GetShape();
    }
    // Update output kernel tensor.
    opt::dynamic_shape::UpdateKernelTensorShape(out_shape, output_kernel_tensors);

    // Resize
    ResizeKernelMod(kernel, input_kernel_tensors, output_kernel_tensors);

    // Malloc workspace memory
    std::vector<device::DeviceAddressPtr> workspace_device_address;
    auto workspace_kernel_tensors = GetWorkspaceKernelTensorsDynamic(device_context, kernel, &workspace_device_address);

    // Update output tensor shape
    UpdateOutputDeviceInfo(output_edges, kernel);

    // Malloc output tensor memory
    AllocateOutputMemory(output_edges, device_context);

    // Launch kernel
    MS_EXCEPTION_IF_NULL(device_context);
    MS_EXCEPTION_IF_NULL(device_context->GetKernelExecutor(true));
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    const size_t stream_id = op_run_info->base_op_run_info.stream_id;
    auto stream = device_context->device_res_manager_->GetStream(stream_id);
    if (!device_context->GetKernelExecutor(true)->LaunchKernel(kernel, input_kernel_tensors, workspace_kernel_tensors,
                                                               output_kernel_tensors, kernel_mod, stream)) {
      MS_LOG(EXCEPTION) << "Launch kernel failed, name:" << kernel->fullname_with_scope();
    }

    if (is_need_infer) {
      if (kernel_mod->IsNeedUpdateOutputShapeAndSize()) {
        kernel_mod->UpdateOutputShapeAndSize(input_kernel_tensors, output_kernel_tensors);
        UpdateOutputShape(output_edges);
      }
    }
    runtime::DeviceAddressUtils::ProcessCrossStreamAddress(op_run_info->base_op_run_info.op_name, device_context,
                                                           stream_id, input_kernel_tensors, output_kernel_tensors);
  }
}

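// Bind each input edge of the simple graph to the corresponding input tensor's device address (handling
// heterogeneous and non-contiguous tensors), creating a new address when the tensor has none.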
void DynamicOpRunner::UpdateInputDeviceAddress(const OpCompilerInfoPtr &op_compiler_info,
                                               const std::vector<tensor::BaseTensorPtr> &input_tensors, bool is_sync) {
  MS_LOG(DEBUG) << "Start update input device address for " << op_compiler_info->graph_info_;
  const auto &simple_graph = op_compiler_info->simple_graph_;
  auto input_tensors_num = input_tensors.size();
  auto op_input_num = simple_graph->inputs_.size();
  if (input_tensors_num != op_input_num) {
    MS_LOG(EXCEPTION) << "Real input tensor's num " << input_tensors_num << " is not equal to op input num "
                      << op_input_num << " !";
  }
  const auto &device_context = op_compiler_info->device_context_;
  const auto &inputs = simple_graph->inputs_;
  for (size_t i = 0; i < input_tensors_num; ++i) {
    const auto &input_tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(input_tensor);
    const auto &input_edge = inputs[i];
    // input_edge->address_ is null.
    UpdateInputTensorForHeterogeneous(device_context, input_tensor, input_edge->origin_address_);
    const auto &device_sync = input_tensor->device_address();
    const auto &device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);

    const auto &input_node = input_edge->node_with_index_.first;
    common::AnfAlgo::SetOutputInferTypeAndShape({input_tensor->data_type()}, {input_tensor->shape()}, input_node.get());
    if (device_address != nullptr) {
      if (device_address->GetTensorStorageInfo() != nullptr) {
        auto new_device_address =
          DeviceAddressUtils::ConvertContiguousDeviceAddress(device_context, device_address, is_sync);
        input_edge->address_ = new_device_address;
        input_tensor->set_device_address(new_device_address);
      } else {
        // Always use tensor address as kernel address.
        input_edge->address_ = device_address;
      }
    } else {
      UpdateAddressInfoByInputTensor(op_compiler_info, input_tensor, input_edge, input_node);
      if (input_edge->ignore_h2d_) {
        input_edge->address_->kernel_tensor()->SetValue(input_tensor);
        MS_LOG(DEBUG) << "Ignore host to device for " << op_compiler_info->graph_info_;
      } else {
        input_tensor->set_device_address(input_edge->address_);
      }
    }
  }
  MS_LOG(DEBUG) << "End update input device address for " << op_compiler_info->graph_info_;
}

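// Copy host data of the real input tensors to their device addresses, allocating device memory first; inputs
// marked ignore_h2d_ or already holding a device pointer are skipped.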
void DynamicOpRunner::CopyHostToDevice(const OpCompilerInfoPtr &op_compiler_info,
                                       const std::vector<tensor::BaseTensorPtr> &input_tensors) {
  const auto &input_edges = op_compiler_info->simple_graph_->inputs_;
  auto input_tensors_num = input_tensors.size();
  auto input_edge_num = input_edges.size();
  if (input_tensors_num != input_edge_num) {
    MS_LOG(EXCEPTION) << "Real input tensor's number " << input_tensors_num << " is not equal to input edges number "
                      << input_edge_num << " !";
  }

  const auto &device_context = op_compiler_info->device_context_;
  for (size_t i = 0; i < input_tensors_num; ++i) {
    const auto &input_tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(input_tensor);
    const auto &device_sync = input_tensor->device_address();
    const auto &device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);

    const auto &input_edge = input_edges[i];
    if (input_edge->ignore_h2d_) {
      continue;
    }

    const auto &input_node = input_edge->node_with_index_.first;
    MS_EXCEPTION_IF_NULL(input_node);
    common::AnfAlgo::SetOutputInferTypeAndShape({input_tensor->data_type()}, {input_tensor->shape()}, input_node.get());

    if (device_address == nullptr) {
      MS_LOG(EXCEPTION) << "Input DeviceAddress cannot be null before copy host to device, op name "
                        << op_compiler_info->graph_info_;
    }

    if (device_address->GetMutablePtr() != nullptr) {
      continue;
    }

    auto mem_type =
      input_tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kPyNativeInput;
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", mem_type, device_address->GetSize(),
                                                   device_address.get());
    if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Device(id:" << device_context->device_context_key().device_id_
                        << ") memory isn't enough and alloc failed, kernel name: " << input_node->DebugString()
                        << ", alloc size: " << device_address->GetSize() << "B.";
    }
    if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), device_address->GetSize(),
                                          device_address->type_id(), "DefaultFormat", input_tensor->data_ptr())) {
      MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
    }
    MS_LOG(DEBUG) << "Copy host tensor to device for op " << op_compiler_info->graph_info_ << " input " << i;
  }
}
}  // namespace mindspore::runtime