1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/framework/actor/copy_actor.h"
18 #include "runtime/framework/actor/memory_manager_actor.h"
19 #include "mindrt/include/async/async.h"
20 #include "utils/log_adapter.h"
21
22 namespace mindspore {
23 namespace runtime {
// Indices into device_contexts_: the copy source context and the copy destination
// context (Init() enforces that exactly two device contexts are present).
const size_t kInputDeviceContextIndex = 0;
const size_t kOutputDeviceContextIndex = 1;
26
Init()27 void CopyActor::Init() {
28 // Check device contexts number.
29 if (device_contexts_.size() != device::kDeviceContextsNumTwo) {
30 MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
31 }
32
33 const size_t kDeviceTensorNum = 1;
34 input_device_tensor_.resize(kDeviceTensorNum);
35 output_device_tensor_.resize(kDeviceTensorNum);
36
37 // Init output data.
38 for (auto &data_arrow : output_data_arrows_) {
39 MS_EXCEPTION_IF_NULL(data_arrow);
40 if (IntToSize(data_arrow->from_output_index_) != 0) {
41 MS_LOG(EXCEPTION) << "The output index is out of range: " << GetAID().Name();
42 }
43 auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
44 (void)output_data_.emplace_back(std::move(data));
45 }
46 }
47
RunOpData(OpData<DeviceTensor> * const input_data,OpContext<DeviceTensor> * const context)48 void CopyActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
49 MS_EXCEPTION_IF_NULL(context);
50 auto &sequential_num = context->sequential_num_;
51 (void)input_op_datas_[sequential_num].emplace_back(input_data);
52 // When all the inputs are collected, then allocate memory and callback copy.
53 if (CheckRunningCondition(context)) {
54 FetchDeviceTensor(context);
55 SendMemoryAllocReq(context);
56 }
57 }
58
RunOpControl(AID * const input_control,OpContext<DeviceTensor> * const context)59 void CopyActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
60 MS_EXCEPTION_IF_NULL(context);
61 auto &sequential_num = context->sequential_num_;
62 (void)input_op_controls_[sequential_num].emplace_back(input_control);
63 // When all the inputs are collected, then allocate memory and callback copy.
64 if (CheckRunningCondition(context)) {
65 FetchDeviceTensor(context);
66 SendMemoryAllocReq(context);
67 }
68 }
69
SendMemoryAllocReq(OpContext<DeviceTensor> * const context)70 void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
71 Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_,
72 device_contexts_[kOutputDeviceContextIndex], context, GetAID());
73 }
74
SendMemoryFreeReq(OpContext<DeviceTensor> * const context)75 void CopyActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
76 Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_,
77 device_contexts_[kInputDeviceContextIndex], context);
78 Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_,
79 device_contexts_[kOutputDeviceContextIndex], context);
80 }
81
OnMemoryAllocFinish(OpContext<DeviceTensor> * const context)82 void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
83 MS_EXCEPTION_IF_NULL(context);
84 MS_EXCEPTION_IF_NULL(output_device_tensor_[0]);
85 MS_EXCEPTION_IF_NULL(input_device_tensor_[0]);
86
87 if (input_device_tensor_[0]->GetSize() != output_device_tensor_[0]->GetSize()) {
88 MS_LOG(WARNING) << GetAID().Name() << " copy size is not equal, input size:" << input_device_tensor_[0]->GetSize()
89 << ", output size:" << output_device_tensor_[0]->GetSize();
90 }
91
92 if (!Copy(output_device_tensor_[0], input_device_tensor_[0])) {
93 std::string error_info = "Copy device tensor failed: " + GetAID().Name();
94 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
95 }
96
97 // The input is invalid and needs to be erased when finish copy.
98 EraseInput(context);
99
100 // Note that SendMemoryFreeReq must be in front of SendOutput, because SendOutput will trigger SendMemoryAllocReq of
101 // the next actor and the actor is asynchronous execution. So it is necessary to ensure that SendMemoryFreeReq of the
102 // current actor is in front of SendMemoryAllocReq of the next actor. One is to reuse the memory more fully, the
103 // other is to ensure the execution order and avoid the illegal memory timing problem.
104 SendMemoryFreeReq(context);
105 SendOutput(context);
106 }
107
FetchDeviceTensor(OpContext<DeviceTensor> * const context)108 void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
109 MS_EXCEPTION_IF_NULL(context);
110 const auto &input_device_context = device_contexts_[kInputDeviceContextIndex];
111 const auto &output_device_context = device_contexts_[kOutputDeviceContextIndex];
112 MS_EXCEPTION_IF_NULL(input_device_context);
113 MS_EXCEPTION_IF_NULL(output_device_context);
114
115 if (device_tensor_store_keys_.size() > 0) {
116 const auto &device_tensor_store_node = device_tensor_store_keys_[0].second;
117 MS_EXCEPTION_IF_NULL(device_tensor_store_node);
118 input_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_node.get(),
119 input_device_context->GetDeviceAddressType());
120 if (input_device_tensor_[0] == nullptr) {
121 std::string error_info =
122 GetAID().Name() + " get device tensor store failed: " + device_tensor_store_node->fullname_with_scope() +
123 ", device type:" + std::to_string(static_cast<int>(input_device_context->GetDeviceAddressType()));
124 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
125 }
126
127 output_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_node.get(),
128 output_device_context->GetDeviceAddressType());
129 if (output_device_tensor_[0] == nullptr) {
130 std::string error_info =
131 GetAID().Name() + " get device tensor store failed: " + device_tensor_store_node->fullname_with_scope() +
132 ", device type:" + std::to_string(static_cast<int>(output_device_context->GetDeviceAddressType()));
133 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
134 }
135 } else {
136 const auto &data_iter = input_op_datas_.find(context->sequential_num_);
137 if (data_iter == input_op_datas_.end()) {
138 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "No input data.");
139 }
140 const auto &input_data = data_iter->second[0];
141 MS_EXCEPTION_IF_NULL(input_data);
142 input_device_tensor_[0] = input_data->data_;
143
144 MS_EXCEPTION_IF_NULL(output_);
145 output_device_tensor_[0] = output_.get();
146 }
147 }
148
SendOutput(OpContext<DeviceTensor> * const context) const149 void CopyActor::SendOutput(OpContext<DeviceTensor> *const context) const {
150 MS_EXCEPTION_IF_NULL(context);
151 // No output.
152 if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0)) {
153 SET_OPCONTEXT_SUCCESS_RET((*context));
154 }
155
156 // Send output data.
157 for (auto &output_data : output_data_) {
158 MS_EXCEPTION_IF_NULL(output_data);
159 output_data->data_ = output_device_tensor_[0];
160 Async(output_data->op_id_, &OpActor::RunOpData, output_data.get(), context);
161 }
162
163 // Send output control.
164 if (output_control_arrows_.size() > 0) {
165 auto source_aid = const_cast<AID *>(&GetAID());
166 for (auto &output_control : output_control_arrows_) {
167 Async(output_control, &OpActor::RunOpControl, source_aid, context);
168 }
169 }
170 }
171 } // namespace runtime
172 } // namespace mindspore
173