1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/framework/actor/loop_count_actor.h"
18 #include "runtime/framework/actor/data_prepare_actor.h"
19 #include "runtime/framework/actor/output_actor.h"
20 #include "runtime/framework/actor/memory_manager_actor.h"
21 #include "runtime/framework/actor/recorder_actor.h"
22 #include "runtime/framework/actor/debug_actor.h"
23 #include "mindrt/include/async/async.h"
24 #include "utils/log_adapter.h"
25
26 namespace mindspore {
27 namespace runtime {
RunOpControl(AID * const input_control,OpContext<DeviceTensor> * const context)28 void LoopCountActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
29 MS_EXCEPTION_IF_NULL(context);
30 auto sequential_num = context->sequential_num_;
31 (void)input_op_controls_[sequential_num].emplace_back(input_control);
32 if (CheckRunningCondition(context)) {
33 // Need wait MemoryManagerActor running finished to avoid the illegal memory timing problem before
34 // LoopCountActor exits, because other processors which are not in actor also will process device tensor.
35 Async(memory_manager_aid_, &MemoryManagerActor::Wait, context, GetAID());
36 }
37 }
38
OnMemoryAllocFinish(OpContext<DeviceTensor> * const context)39 void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
40 MS_EXCEPTION_IF_NULL(context);
41 IncreaseLoopCount(context);
42 }
43
IncreaseLoopCount(OpContext<DeviceTensor> * const context)44 void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) {
45 MS_EXCEPTION_IF_NULL(context);
46 EraseInput(context);
47
48 total_running_count_++;
49 current_count_++;
50 MS_LOG(INFO) << "Loop count actor(" << GetAID().Name() << ") running, loop count: " << loop_count_
51 << ", current count: " << current_count_ << ", total running count: " << total_running_count_;
52
53 // Debug actor is blocked, must wait debug actor callback message to process continue.
54 if (debug_aid_ != nullptr) {
55 SendDebugReq(context);
56 return;
57 }
58
59 SendOutput(context);
60 }
61
SendDebugReq(OpContext<DeviceTensor> * const context)62 void LoopCountActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
63 Async(*debug_aid_, &DebugActor::DebugOnStepEnd, context, &GetAID());
64 }
65
OnDebugFinish(OpContext<DeviceTensor> * const context)66 void LoopCountActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
67 MS_EXCEPTION_IF_NULL(context);
68 SendOutput(context);
69 }
70
SendOutput(OpContext<DeviceTensor> * const context)71 void LoopCountActor::SendOutput(OpContext<DeviceTensor> *const context) {
72 // Send recorder info.
73 if (recorder_aid_ != nullptr) {
74 Async(*recorder_aid_, &RecorderActor::RecordOnStepEnd, context);
75 }
76
77 // Send loop count to output actor.
78 Async(output_aid_, &OutputActor::CollectLoopCount, current_count_, context);
79
80 // The LoopCountActor exits.
81 if (current_count_ == loop_count_) {
82 current_count_ = 0;
83 return;
84 }
85
86 // Send to DataPrepareActor to trigger next step running.
87 std::vector<std::vector<TensorPtr>> input_tensors;
88 Async(data_prepare_aid_, &DataPrepareActor::PrepareData, input_tensors, context);
89 }
90 } // namespace runtime
91 } // namespace mindspore
92