/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/copy_actor.h"
#include "runtime/graph_scheduler/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "include/backend/mem_reuse/mem_tracker.h"

namespace mindspore {
namespace runtime {
const size_t kInputDeviceContextIndex = 0;
const size_t kOutputDeviceContextIndex = 1;

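// Initialize the copy actor: check the input/output device contexts, reserve the single input and output device
// tensor slots, validate the output data arrows and prepare the output data.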
void CopyActor::Init() {
  // Check device contexts number.
  if (device_contexts_.size() != device::kDeviceContextsNumTwo) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }

  const size_t kDeviceTensorNum = 1;
  input_device_tensor_.resize(kDeviceTensorNum);
  output_device_tensor_.resize(kDeviceTensorNum);

  // Check output data index.
  for (auto &data_arrow : output_data_arrows_) {
    MS_EXCEPTION_IF_NULL(data_arrow);
    if (IntToSize(data_arrow->from_output_index_) != 0) {
      MS_LOG(EXCEPTION) << "The output index is out of range: " << GetAID().Name();
    }
  }

  InitOutputData();
}

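// The actor running entry: record the copy task for the memory tracker, fetch the input/output device tensors and
// then request the output memory allocation.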
void CopyActor::Run(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, GetAID().Name(), GetAID().Name(), "");
  FetchDeviceTensor(context);
  SendMemoryAllocReq(context);
}

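// Request memory for the output device tensor from the memory manager actor, synchronously or asynchronously
// according to the dispatcher's memory allocation mode.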
void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
  if (ActorDispatcher::is_memory_allocation_sync()) {
    ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_,
                              device_contexts_[kOutputDeviceContextIndex], context, GetAID());
    OnMemoryAllocFinish(context);
  } else {
    ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_,
                          device_contexts_[kOutputDeviceContextIndex], context, GetAID());
  }
}

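// Ask the memory manager actor to free both the input and output device tensors, synchronously or asynchronously
// according to the dispatcher's memory free mode.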
void CopyActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  if (ActorDispatcher::is_memory_free_sync()) {
    ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_,
                              device_contexts_[kInputDeviceContextIndex], context, GetAID());
    ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_,
                              device_contexts_[kOutputDeviceContextIndex], context, GetAID());
  } else {
    ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_,
                          device_contexts_[kInputDeviceContextIndex], context, GetAID());
    ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_,
                          device_contexts_[kOutputDeviceContextIndex], context, GetAID());
  }
}

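// Triggered after the output memory allocation finishes: copy the input device tensor to the output device tensor,
// propagate its type, shape and user data, and then run the post-processing to send outputs.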
void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(output_device_tensor_[0]);
  MS_EXCEPTION_IF_NULL(input_device_tensor_[0]);
  if (IsRunningFailed(context)) {
    return;
  }

  if (input_device_tensor_[0]->GetSize() != output_device_tensor_[0]->GetSize()) {
    MS_LOG(WARNING) << GetAID().Name() << " copy size is not equal, input size:" << input_device_tensor_[0]->GetSize()
                    << ", output size:" << output_device_tensor_[0]->GetSize();
  }

  {
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kCopyData, GetAID().Name());
    if (!Copy(output_device_tensor_[0], input_device_tensor_[0])) {
      std::string error_info = "Copy device tensor failed: " + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    output_device_tensor_[0]->kernel_tensor()->SetType(input_device_tensor_[0]->kernel_tensor()->GetType());
    output_device_tensor_[0]->kernel_tensor()->SetShape(input_device_tensor_[0]->kernel_tensor()->GetShape());
    output_device_tensor_[0]->set_user_data(input_device_tensor_[0]->user_data());
    MS_LOG(DEBUG) << "Set user data:" << input_device_tensor_[0]->user_data()
                  << " shape:" << input_device_tensor_[0]->kernel_tensor()->GetShape()->ToString()
                  << " from device tensor:" << input_device_tensor_[0]
                  << " to device address:" << output_device_tensor_[0];
    output_device_tensor_[0]->set_need_sync_user_data(input_device_tensor_[0]->need_sync_user_data());
  }

  PostRun(context);
}

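// Resolve the input and output device tensors, either from the device tensor store or from the received input data,
// and update the output size, type and shape from the input when needed.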
void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  MS_EXCEPTION_IF_NULL(context);
  const auto &input_device_context = device_contexts_[kInputDeviceContextIndex];
  const auto &output_device_context = device_contexts_[kOutputDeviceContextIndex];
  MS_EXCEPTION_IF_NULL(input_device_context);
  MS_EXCEPTION_IF_NULL(output_device_context);

  if (device_tensor_store_keys_.size() > 0) {
    const auto &device_tensor_store_node = device_tensor_store_keys_[0].second;
    MS_EXCEPTION_IF_NULL(device_tensor_store_node);
    input_device_tensor_[0] = DeviceTensorStore::GetInstance()
                                .Fetch(device_tensor_store_node.get(), input_device_context->GetDeviceType())
                                .get();
    if (input_device_tensor_[0] == nullptr) {
      std::string error_info =
        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_node->fullname_with_scope() +
        ", device type:" + std::to_string(static_cast<int>(input_device_context->GetDeviceType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }

    output_device_tensor_[0] = DeviceTensorStore::GetInstance()
                                 .Fetch(device_tensor_store_node.get(), output_device_context->GetDeviceType())
                                 .get();
    if (output_device_tensor_[0] == nullptr) {
      std::string error_info =
        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_node->fullname_with_scope() +
        ", device type:" + std::to_string(static_cast<int>(output_device_context->GetDeviceType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  } else {
    const auto &data_iter = input_op_datas_.find(context->sequential_num_);
    if (data_iter == input_op_datas_.end()) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "No input data.");
    }
    const auto &input_data = data_iter->second[0];
    MS_EXCEPTION_IF_NULL(input_data);
    input_device_tensor_[0] = input_data->data_;

    MS_EXCEPTION_IF_NULL(output_);
    output_device_tensor_[0] = output_.get();
  }

  if (!WaitRuntimePipelineFinish(context)) {
    MS_LOG(INFO) << "Run failed and early stop.";
    return;
  }
  if (is_need_update_output_size_ && (input_device_tensor_[0]->GetSize() != output_device_tensor_[0]->GetSize())) {
    MS_LOG(DEBUG) << GetAID().Name() << " update output size from " << output_device_tensor_[0]->GetSize() << " to "
                  << input_device_tensor_[0]->GetSize();
    output_device_tensor_[0]->SetSize(input_device_tensor_[0]->GetSize());
    const auto &output_kernel_tensor = output_device_tensor_[0]->kernel_tensor();
    const auto &input_kernel_tensor = input_device_tensor_[0]->kernel_tensor();
    MS_EXCEPTION_IF_NULL(output_kernel_tensor);
    MS_EXCEPTION_IF_NULL(input_kernel_tensor);
    output_kernel_tensor->SetType(input_kernel_tensor->GetType()->Clone());
    output_kernel_tensor->SetShape(input_kernel_tensor->GetShape()->Clone());
  }
}

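// Fill the output data with the copied output device tensor before it is sent to the downstream actors.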
void CopyActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &, const AnfNodePtr &,
                                 OpContext<DeviceTensor> *const) {
  MS_EXCEPTION_IF_NULL(output_data);
  output_data->data_ = output_device_tensor_[0];
}
}  // namespace runtime
}  // namespace mindspore