/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_

#include <algorithm>
#include <vector>
#include <string>
#include <memory>
#include <map>
#include <utility>
#include "utils/hash_map.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "runtime/graph_scheduler/actor/debug_aware_actor.h"
#include "runtime/graph_scheduler/device_tensor_store.h"
#include "runtime/graph_scheduler/control_node_parser.h"

namespace mindspore {
namespace runtime {
// The loop count actor receives control from the tail kernel actor, which marks the end of one step, and decides
// whether to run another step according to the loop count.
class LoopCountActor : public DebugAwareActor {
 public:
  LoopCountActor(const std::string &name, const std::string &graph_name, size_t loop_count,
                 const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid, const AID *profiler_aid,
                 GraphExecutionStrategy strategy, const std::vector<DeviceContext *> &device_contexts,
                 const bool is_need_sync_stream)
      : DebugAwareActor(name, KernelTransformType::kLoopCountActor, recorder_aid, memory_manager_aid, debug_aid,
                        profiler_aid),
        graph_name_(graph_name),
        loop_count_(loop_count),
        current_count_(0),
        total_running_count_(0),
        strategy_(strategy),
        is_need_sync_stream_(is_need_sync_stream) {
    (void)std::transform(
      device_contexts.begin(), device_contexts.end(), std::back_inserter(device_contexts_),
      [](DeviceContext *device_context) { return static_cast<const DeviceContext *>(device_context); });
  }

  ~LoopCountActor() override = default;

  // The callback waits for the memory manager actor to finish all the message processing.
  void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;

  // The debug related operation interface.
  void SendDebugReq(OpContext<DeviceTensor> *const context) override;
  void SendProfilerReq(OpContext<DeviceTensor> *const context);

  // Get the members.
  size_t loop_count() const { return loop_count_; }
  const AID &data_prepare_aid() const { return data_prepare_aid_; }
  const std::vector<AID> &entrance_aids() const { return entrance_aids_; }

 protected:
  void Run(OpContext<DeviceTensor> *const context) override;
  void SendOutput(OpContext<DeviceTensor> *const context) override;

 private:
  friend class GraphScheduler;
  friend class ControlNodeScheduler;

  // Increase the current count when one step finishes and decide whether to continue looping.
  void IncreaseLoopCount(OpContext<DeviceTensor> *const context);

  // Graph name of GraphCompilerInfo. For example, kernel_graph_0-3.
  std::string graph_name_;

  // The loop count is constant; the current count is increased after each step finishes running.
  size_t loop_count_;
  size_t current_count_;
  // The total running count represents the total number of steps that have run.
  size_t total_running_count_;

  // The actors which need to be handled separately by the loop count actor.
  AID data_prepare_aid_;
  std::vector<AID> entrance_aids_;

  // The execution strategy of the actor.
  // In pipeline mode, sync stream for every step.
  GraphExecutionStrategy strategy_{GraphExecutionStrategy::kPipeline};

  // Only need sync stream in DR scenarios.
  bool is_need_sync_stream_{true};
};

using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_