1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_ 19 20 #include <vector> 21 #include <string> 22 #include <memory> 23 #include <unordered_map> 24 #include <queue> 25 #include <utility> 26 #include "runtime/framework/actor/actor_common.h" 27 #include "runtime/framework/actor/debug_aware_actor.h" 28 #include "runtime/hardware/device_context.h" 29 #include "runtime/framework/device_tensor_store.h" 30 #include "runtime/framework/host_tensor_queue.h" 31 #include "base/base.h" 32 33 namespace mindspore { 34 namespace runtime { 35 using mindspore::device::DeviceContext; 36 using mindspore::device::KernelInfo; 37 using mindspore::kernel::KernelLaunchInfo; 38 39 // The data source actor is used to fetch data from data source and process them into device tensors, 40 // and then send them to kernel actor. The processing flow is FetchData -> FillDataBuffer -> SendMemoryAllocReq 41 // -> OnMemoryAllocFinish -> SendMemoryFreeReq -> SendOutput. 42 class DataSourceActor : public DebugAwareActor { 43 public: DataSourceActor(const std::string & name,KernelTransformType type,size_t buffer_capacity,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid)44 DataSourceActor(const std::string &name, KernelTransformType type, size_t buffer_capacity, 45 const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid) 46 : DebugAwareActor(name, type, recorder_aid, memory_manager_aid, debug_aid), buffer_capacity_(buffer_capacity) {} 47 virtual ~DataSourceActor() = default; 48 49 void Init() override; 50 51 // The process entry of data processing. 52 void FetchData(OpContext<DeviceTensor> *const context); 53 54 // The memory related operation interface. SendMemoryAllocReq(OpContext<DeviceTensor> * const context)55 void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override{}; SendMemoryFreeReq(OpContext<DeviceTensor> * const context)56 void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override{}; 57 // Copy data from data source to the device tensor buffer of actor after memory alloc finished. OnMemoryAllocFinish(OpContext<DeviceTensor> * const context)58 void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override{}; 59 60 protected: 61 friend class GraphScheduler; 62 63 // Construct the device tensors and fill to device tensor buffer from the member nodes during the data fetching. 64 virtual void FillDataBuffer() = 0; 65 66 // Send output result of graph output to output actor. 67 virtual void SendResult(OpContext<DeviceTensor> *const context) = 0; 68 69 // Send recorder info to recorder actor, only the device queue data source actor need. SendRecorderInfo(OpContext<DeviceTensor> * const context)70 virtual void SendRecorderInfo(OpContext<DeviceTensor> *const context) {} 71 72 // Send output to downstream actors to trigger computing after fetching data finished. 73 void SendOutput(OpContext<DeviceTensor> *const context); 74 75 // The buffers store the device tensors. 76 std::queue<std::vector<DeviceTensor *>> buffers_; 77 size_t buffer_capacity_; 78 79 // The output_data_ corresponds to the output_data_arrows_ one by one. 80 std::vector<OpDataUniquePtr<DeviceTensor>> output_data_; 81 }; 82 83 // The class represents that the data source is device queue. 84 class DeviceQueueDataSourceActor : public DataSourceActor { 85 public: DeviceQueueDataSourceActor(const std::string & name,size_t buffer_capacity,const DeviceContext * device_context,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid)86 DeviceQueueDataSourceActor(const std::string &name, size_t buffer_capacity, const DeviceContext *device_context, 87 const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid) 88 : DataSourceActor(name, KernelTransformType::kDeviceDataSourceActor, buffer_capacity, memory_manager_aid, 89 debug_aid, recorder_aid) { 90 (void)device_contexts_.emplace_back(device_context); 91 } 92 ~DeviceQueueDataSourceActor() override = default; 93 94 void Init() override; 95 96 void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override; 97 void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override; 98 void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override; 99 100 void SendDebugReq(OpContext<DeviceTensor> *const context) override; 101 void OnDebugFinish(OpContext<DeviceTensor> *const context) override; 102 103 protected: 104 void FillDataBuffer() override; 105 void SendResult(OpContext<DeviceTensor> *const context) override; 106 void SendRecorderInfo(OpContext<DeviceTensor> *const context) override; 107 108 private: 109 friend class GraphScheduler; 110 111 // Input data kernel(for example GetNext) fetches data from device queue. 112 CNodePtr data_kernel_{nullptr}; 113 KernelInfo *kernel_info_{nullptr}; 114 115 // The kernel launch info is fetched by the device tensors. 116 KernelLaunchInfo launch_info_; 117 }; 118 119 // The class represents that the data source is host queue. 120 class HostQueueDataSourceActor : public DataSourceActor { 121 public: HostQueueDataSourceActor(std::string name,size_t buffer_capacity,const AID memory_manager_aid,const AID * debug_aid,const AID * recorder_aid,HostTensorQueuePtr host_queue)122 HostQueueDataSourceActor(std::string name, size_t buffer_capacity, const AID memory_manager_aid, const AID *debug_aid, 123 const AID *recorder_aid, HostTensorQueuePtr host_queue) 124 : DataSourceActor(name, KernelTransformType::kHostDataSourceActor, buffer_capacity, memory_manager_aid, debug_aid, 125 recorder_aid), 126 host_queue_(host_queue) {} 127 ~HostQueueDataSourceActor() override = default; 128 129 void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override; 130 void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override; 131 void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override; 132 133 size_t FetchNodePosition(const AnfNodePtr &node) const override; 134 AnfNodePtr FetchNode(size_t node_position) const; data_nodes()135 const std::vector<AnfNodePtr> &data_nodes() const { return data_nodes_; } 136 137 protected: 138 void FillDataBuffer() override; 139 void SendResult(OpContext<DeviceTensor> *const context) override; 140 141 private: 142 friend class GraphScheduler; 143 144 // Judge all the data_nodes_ is from the same device. 145 bool IsSameDeviceType() const; 146 147 HostTensorQueuePtr host_queue_; 148 // Input data nodes fetch data from host queue. 149 std::vector<AnfNodePtr> data_nodes_; 150 151 // The location of the data node in the data source actor. 152 std::unordered_map<AnfNodePtr, size_t> data_node_position_map_; 153 }; 154 155 using DataSourceActorPtr = std::shared_ptr<DataSourceActor>; 156 using DeviceQueueDSActorPtr = std::shared_ptr<DeviceQueueDataSourceActor>; 157 using HostQueueDSActorPtr = std::shared_ptr<HostQueueDataSourceActor>; 158 159 } // namespace runtime 160 } // namespace mindspore 161 162 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_ 163