• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
19 
20 #include <algorithm>
21 #include <vector>
22 #include <string>
23 #include <memory>
24 #include <map>
25 #include <utility>
26 #include "utils/hash_map.h"
27 #include "runtime/graph_scheduler/actor/actor_common.h"
28 #include "runtime/graph_scheduler/actor/debug_aware_actor.h"
29 #include "runtime/graph_scheduler/device_tensor_store.h"
30 #include "runtime/graph_scheduler/control_node_parser.h"
31 
32 namespace mindspore {
33 namespace runtime {
34 // The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
35 // and decide whether to loop execution by loop count.
36 class LoopCountActor : public DebugAwareActor {
37  public:
LoopCountActor(const std::string & name,const std::string & graph_name,size_t loop_count,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid,const AID * profiler_aid,GraphExecutionStrategy strategy,const std::vector<DeviceContext * > & device_contexts,const bool is_need_sync_stream)38   LoopCountActor(const std::string &name, const std::string &graph_name, size_t loop_count,
39                  const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid, const AID *profiler_aid,
40                  GraphExecutionStrategy strategy, const std::vector<DeviceContext *> &device_contexts,
41                  const bool is_need_sync_stream)
42       : DebugAwareActor(name, KernelTransformType::kLoopCountActor, recorder_aid, memory_manager_aid, debug_aid,
43                         profiler_aid),
44         graph_name_(graph_name),
45         loop_count_(loop_count),
46         current_count_(0),
47         total_running_count_(0),
48         strategy_(strategy),
49         is_need_sync_stream_(is_need_sync_stream) {
50     (void)std::transform(
51       device_contexts.begin(), device_contexts.end(), std::back_inserter(device_contexts_),
52       [](DeviceContext *device_context) { return static_cast<const DeviceContext *>(device_context); });
53   }
54 
55   ~LoopCountActor() override = default;
56 
57   // The callback waits for the memory manager actor to finish all the message processing.
58   void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
59 
60   // The debug related operation interface.
61   void SendDebugReq(OpContext<DeviceTensor> *const context) override;
62   void SendProfilerReq(OpContext<DeviceTensor> *const context);
63 
64   // Get the member.
loop_count()65   size_t loop_count() const { return loop_count_; }
data_prepare_aid()66   const AID &data_prepare_aid() const { return data_prepare_aid_; }
entrance_aids()67   const std::vector<AID> &entrance_aids() const { return entrance_aids_; }
68 
69  protected:
70   void Run(OpContext<DeviceTensor> *const context) override;
71   void SendOutput(OpContext<DeviceTensor> *const context) override;
72 
73  private:
74   friend class GraphScheduler;
75   friend class ControlNodeScheduler;
76 
77   void IncreaseLoopCount(OpContext<DeviceTensor> *const context);
78 
79   // Graph name of GraphCompilerInfo. For example, kernel_graph_0-3.
80   std::string graph_name_;
81 
82   // The loop count is constant, the current count is increased after each step running finished.
83   size_t loop_count_;
84   size_t current_count_;
85   // The total running count represents the toal step running count.
86   size_t total_running_count_;
87 
88   // The actors which need be handled separately by loop count actor.
89   AID data_prepare_aid_;
90   std::vector<AID> entrance_aids_;
91 
92   // The execution strategy for executing actor.
93   // In pipeline mode,  sync stream for every step.
94   GraphExecutionStrategy strategy_{GraphExecutionStrategy::kPipeline};
95 
96   // Only need sync stream in DR scenarios.
97   bool is_need_sync_stream_{true};
98 };
99 
100 using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;
101 }  // namespace runtime
102 }  // namespace mindspore
103 
104 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
105