• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/framework/actor/copy_actor.h"
18 #include "runtime/framework/actor/memory_manager_actor.h"
19 #include "mindrt/include/async/async.h"
20 #include "utils/log_adapter.h"
21 
22 namespace mindspore {
23 namespace runtime {
// Positions inside device_contexts_: slot 0 is the device the copy reads from (input side),
// slot 1 is the device the copy writes to (output side).
constexpr size_t kInputDeviceContextIndex = 0;
constexpr size_t kOutputDeviceContextIndex = 1;
26 
Init()27 void CopyActor::Init() {
28   // Check device contexts number.
29   if (device_contexts_.size() != device::kDeviceContextsNumTwo) {
30     MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
31   }
32 
33   const size_t kDeviceTensorNum = 1;
34   input_device_tensor_.resize(kDeviceTensorNum);
35   output_device_tensor_.resize(kDeviceTensorNum);
36 
37   // Init output data.
38   for (auto &data_arrow : output_data_arrows_) {
39     MS_EXCEPTION_IF_NULL(data_arrow);
40     if (IntToSize(data_arrow->from_output_index_) != 0) {
41       MS_LOG(EXCEPTION) << "The output index is out of range: " << GetAID().Name();
42     }
43     auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
44     (void)output_data_.emplace_back(std::move(data));
45   }
46 }
47 
RunOpData(OpData<DeviceTensor> * const input_data,OpContext<DeviceTensor> * const context)48 void CopyActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
49   MS_EXCEPTION_IF_NULL(context);
50   auto &sequential_num = context->sequential_num_;
51   (void)input_op_datas_[sequential_num].emplace_back(input_data);
52   // When all the inputs are collected, then allocate memory and callback copy.
53   if (CheckRunningCondition(context)) {
54     FetchDeviceTensor(context);
55     SendMemoryAllocReq(context);
56   }
57 }
58 
RunOpControl(AID * const input_control,OpContext<DeviceTensor> * const context)59 void CopyActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
60   MS_EXCEPTION_IF_NULL(context);
61   auto &sequential_num = context->sequential_num_;
62   (void)input_op_controls_[sequential_num].emplace_back(input_control);
63   // When all the inputs are collected, then allocate memory and callback copy.
64   if (CheckRunningCondition(context)) {
65     FetchDeviceTensor(context);
66     SendMemoryAllocReq(context);
67   }
68 }
69 
SendMemoryAllocReq(OpContext<DeviceTensor> * const context)70 void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
71   Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_,
72         device_contexts_[kOutputDeviceContextIndex], context, GetAID());
73 }
74 
SendMemoryFreeReq(OpContext<DeviceTensor> * const context)75 void CopyActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
76   Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_,
77         device_contexts_[kInputDeviceContextIndex], context);
78   Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_,
79         device_contexts_[kOutputDeviceContextIndex], context);
80 }
81 
OnMemoryAllocFinish(OpContext<DeviceTensor> * const context)82 void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
83   MS_EXCEPTION_IF_NULL(context);
84   MS_EXCEPTION_IF_NULL(output_device_tensor_[0]);
85   MS_EXCEPTION_IF_NULL(input_device_tensor_[0]);
86 
87   if (input_device_tensor_[0]->GetSize() != output_device_tensor_[0]->GetSize()) {
88     MS_LOG(WARNING) << GetAID().Name() << " copy size is not equal, input size:" << input_device_tensor_[0]->GetSize()
89                     << ", output size:" << output_device_tensor_[0]->GetSize();
90   }
91 
92   if (!Copy(output_device_tensor_[0], input_device_tensor_[0])) {
93     std::string error_info = "Copy device tensor failed: " + GetAID().Name();
94     SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
95   }
96 
97   // The input is invalid and needs to be erased when finish copy.
98   EraseInput(context);
99 
100   // Note that SendMemoryFreeReq must be in front of SendOutput, because SendOutput will trigger SendMemoryAllocReq of
101   // the next actor and the actor is asynchronous execution. So it is necessary to ensure that SendMemoryFreeReq of the
102   // current actor is in front of SendMemoryAllocReq of the next actor.  One is to reuse the memory more fully, the
103   // other is to ensure the execution order and avoid the illegal memory timing problem.
104   SendMemoryFreeReq(context);
105   SendOutput(context);
106 }
107 
FetchDeviceTensor(OpContext<DeviceTensor> * const context)108 void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
109   MS_EXCEPTION_IF_NULL(context);
110   const auto &input_device_context = device_contexts_[kInputDeviceContextIndex];
111   const auto &output_device_context = device_contexts_[kOutputDeviceContextIndex];
112   MS_EXCEPTION_IF_NULL(input_device_context);
113   MS_EXCEPTION_IF_NULL(output_device_context);
114 
115   if (device_tensor_store_keys_.size() > 0) {
116     const auto &device_tensor_store_node = device_tensor_store_keys_[0].second;
117     MS_EXCEPTION_IF_NULL(device_tensor_store_node);
118     input_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_node.get(),
119                                                                      input_device_context->GetDeviceAddressType());
120     if (input_device_tensor_[0] == nullptr) {
121       std::string error_info =
122         GetAID().Name() + " get device tensor store failed: " + device_tensor_store_node->fullname_with_scope() +
123         ", device type:" + std::to_string(static_cast<int>(input_device_context->GetDeviceAddressType()));
124       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
125     }
126 
127     output_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_node.get(),
128                                                                       output_device_context->GetDeviceAddressType());
129     if (output_device_tensor_[0] == nullptr) {
130       std::string error_info =
131         GetAID().Name() + " get device tensor store failed: " + device_tensor_store_node->fullname_with_scope() +
132         ", device type:" + std::to_string(static_cast<int>(output_device_context->GetDeviceAddressType()));
133       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
134     }
135   } else {
136     const auto &data_iter = input_op_datas_.find(context->sequential_num_);
137     if (data_iter == input_op_datas_.end()) {
138       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "No input data.");
139     }
140     const auto &input_data = data_iter->second[0];
141     MS_EXCEPTION_IF_NULL(input_data);
142     input_device_tensor_[0] = input_data->data_;
143 
144     MS_EXCEPTION_IF_NULL(output_);
145     output_device_tensor_[0] = output_.get();
146   }
147 }
148 
SendOutput(OpContext<DeviceTensor> * const context) const149 void CopyActor::SendOutput(OpContext<DeviceTensor> *const context) const {
150   MS_EXCEPTION_IF_NULL(context);
151   // No output.
152   if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0)) {
153     SET_OPCONTEXT_SUCCESS_RET((*context));
154   }
155 
156   // Send output data.
157   for (auto &output_data : output_data_) {
158     MS_EXCEPTION_IF_NULL(output_data);
159     output_data->data_ = output_device_tensor_[0];
160     Async(output_data->op_id_, &OpActor::RunOpData, output_data.get(), context);
161   }
162 
163   // Send output control.
164   if (output_control_arrows_.size() > 0) {
165     auto source_aid = const_cast<AID *>(&GetAID());
166     for (auto &output_control : output_control_arrows_) {
167       Async(output_control, &OpActor::RunOpControl, source_aid, context);
168     }
169   }
170 }
171 }  // namespace runtime
172 }  // namespace mindspore
173