1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/graph_scheduler/actor/copy_actor.h"
18 #include "runtime/graph_scheduler/actor/memory_manager_actor.h"
19 #include "mindrt/include/async/async.h"
20 #include "utils/log_adapter.h"
21 #include "include/backend/mem_reuse/mem_tracker.h"
22
23 namespace mindspore {
24 namespace runtime {
// Position of the input (copy-from) device context inside device_contexts_.
constexpr size_t kInputDeviceContextIndex = 0;
// Position of the output (copy-to) device context inside device_contexts_.
constexpr size_t kOutputDeviceContextIndex = 1;
27
Init()28 void CopyActor::Init() {
29 // Check device contexts number.
30 if (device_contexts_.size() != device::kDeviceContextsNumTwo) {
31 MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
32 }
33
34 const size_t kDeviceTensorNum = 1;
35 input_device_tensor_.resize(kDeviceTensorNum);
36 output_device_tensor_.resize(kDeviceTensorNum);
37
38 // Check output data index.
39 for (auto &data_arrow : output_data_arrows_) {
40 MS_EXCEPTION_IF_NULL(data_arrow);
41 if (IntToSize(data_arrow->from_output_index_) != 0) {
42 MS_LOG(EXCEPTION) << "The output index is out of range: " << GetAID().Name();
43 }
44 }
45
46 InitOutputData();
47 }
48
Run(OpContext<DeviceTensor> * const context)49 void CopyActor::Run(OpContext<DeviceTensor> *const context) {
50 MS_EXCEPTION_IF_NULL(context);
51 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, GetAID().Name(), GetAID().Name(), "");
52 FetchDeviceTensor(context);
53 SendMemoryAllocReq(context);
54 }
55
SendMemoryAllocReq(OpContext<DeviceTensor> * const context)56 void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
57 if (ActorDispatcher::is_memory_allocation_sync()) {
58 ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_,
59 device_contexts_[kOutputDeviceContextIndex], context, GetAID());
60 OnMemoryAllocFinish(context);
61 } else {
62 ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_,
63 device_contexts_[kOutputDeviceContextIndex], context, GetAID());
64 }
65 }
66
SendMemoryFreeReq(OpContext<DeviceTensor> * const context)67 void CopyActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
68 if (ActorDispatcher::is_memory_free_sync()) {
69 ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_,
70 device_contexts_[kInputDeviceContextIndex], context, GetAID());
71 ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_,
72 device_contexts_[kOutputDeviceContextIndex], context, GetAID());
73 } else {
74 ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_,
75 device_contexts_[kInputDeviceContextIndex], context, GetAID());
76 ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_,
77 device_contexts_[kOutputDeviceContextIndex], context, GetAID());
78 }
79 }
80
OnMemoryAllocFinish(OpContext<DeviceTensor> * const context)81 void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
82 MS_EXCEPTION_IF_NULL(context);
83 MS_EXCEPTION_IF_NULL(output_device_tensor_[0]);
84 MS_EXCEPTION_IF_NULL(input_device_tensor_[0]);
85 if (IsRunningFailed(context)) {
86 return;
87 }
88
89 if (input_device_tensor_[0]->GetSize() != output_device_tensor_[0]->GetSize()) {
90 MS_LOG(WARNING) << GetAID().Name() << " copy size is not equal, input size:" << input_device_tensor_[0]->GetSize()
91 << ", output size:" << output_device_tensor_[0]->GetSize();
92 }
93
94 {
95 ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kCopyData, GetAID().Name());
96 if (!Copy(output_device_tensor_[0], input_device_tensor_[0])) {
97 std::string error_info = "Copy device tensor failed: " + GetAID().Name();
98 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
99 }
100 output_device_tensor_[0]->kernel_tensor()->SetType(input_device_tensor_[0]->kernel_tensor()->GetType());
101 output_device_tensor_[0]->kernel_tensor()->SetShape(input_device_tensor_[0]->kernel_tensor()->GetShape());
102 output_device_tensor_[0]->set_user_data(input_device_tensor_[0]->user_data());
103 MS_LOG(DEBUG) << "Set user data:" << input_device_tensor_[0]->user_data()
104 << " shape:" << input_device_tensor_[0]->kernel_tensor()->GetShape()->ToString()
105 << " from device tensor:" << input_device_tensor_[0]
106 << " to device address:" << output_device_tensor_[0];
107 output_device_tensor_[0]->set_need_sync_user_data(input_device_tensor_[0]->need_sync_user_data());
108 }
109
110 PostRun(context);
111 }
112
FetchDeviceTensor(OpContext<DeviceTensor> * const context)113 void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
114 ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
115 MS_EXCEPTION_IF_NULL(context);
116 const auto &input_device_context = device_contexts_[kInputDeviceContextIndex];
117 const auto &output_device_context = device_contexts_[kOutputDeviceContextIndex];
118 MS_EXCEPTION_IF_NULL(input_device_context);
119 MS_EXCEPTION_IF_NULL(output_device_context);
120
121 if (device_tensor_store_keys_.size() > 0) {
122 const auto &device_tensor_store_node = device_tensor_store_keys_[0].second;
123 MS_EXCEPTION_IF_NULL(device_tensor_store_node);
124 input_device_tensor_[0] = DeviceTensorStore::GetInstance()
125 .Fetch(device_tensor_store_node.get(), input_device_context->GetDeviceType())
126 .get();
127 if (input_device_tensor_[0] == nullptr) {
128 std::string error_info =
129 GetAID().Name() + " get device tensor store failed: " + device_tensor_store_node->fullname_with_scope() +
130 ", device type:" + std::to_string(static_cast<int>(input_device_context->GetDeviceType()));
131 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
132 }
133
134 output_device_tensor_[0] = DeviceTensorStore::GetInstance()
135 .Fetch(device_tensor_store_node.get(), output_device_context->GetDeviceType())
136 .get();
137 if (output_device_tensor_[0] == nullptr) {
138 std::string error_info =
139 GetAID().Name() + " get device tensor store failed: " + device_tensor_store_node->fullname_with_scope() +
140 ", device type:" + std::to_string(static_cast<int>(output_device_context->GetDeviceType()));
141 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
142 }
143 } else {
144 const auto &data_iter = input_op_datas_.find(context->sequential_num_);
145 if (data_iter == input_op_datas_.end()) {
146 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "No input data.");
147 }
148 const auto &input_data = data_iter->second[0];
149 MS_EXCEPTION_IF_NULL(input_data);
150 input_device_tensor_[0] = input_data->data_;
151
152 MS_EXCEPTION_IF_NULL(output_);
153 output_device_tensor_[0] = output_.get();
154 }
155
156 if (!WaitRuntimePipelineFinish(context)) {
157 MS_LOG(INFO) << "Run failed and early stop.";
158 return;
159 }
160 if (is_need_update_output_size_ && (input_device_tensor_[0]->GetSize() != output_device_tensor_[0]->GetSize())) {
161 MS_LOG(DEBUG) << GetAID().Name() << " update output size from " << output_device_tensor_[0]->GetSize() << " to "
162 << input_device_tensor_[0]->GetSize();
163 output_device_tensor_[0]->SetSize(input_device_tensor_[0]->GetSize());
164 const auto &output_kernel_tensor = output_device_tensor_[0]->kernel_tensor();
165 const auto &input_kernel_tensor = input_device_tensor_[0]->kernel_tensor();
166 MS_EXCEPTION_IF_NULL(output_kernel_tensor);
167 MS_EXCEPTION_IF_NULL(input_kernel_tensor);
168 output_kernel_tensor->SetType(input_kernel_tensor->GetType()->Clone());
169 output_kernel_tensor->SetShape(input_kernel_tensor->GetShape()->Clone());
170 }
171 }
172
UpdateOutputData(OpData<DeviceTensor> * const output_data,const DataArrowPtr &,const AnfNodePtr &,OpContext<DeviceTensor> * const)173 void CopyActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &, const AnfNodePtr &,
174 OpContext<DeviceTensor> *const) {
175 MS_EXCEPTION_IF_NULL(output_data);
176 output_data->data_ = output_device_tensor_[0];
177 }
178 } // namespace runtime
179 } // namespace mindspore
180