/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_

#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <queue>
#include <utility>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/debug_aware_actor.h"
#include "runtime/hardware/device_context.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/host_tensor_queue.h"
#include "base/base.h"

namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::device::KernelInfo;
using mindspore::kernel::KernelLaunchInfo;

// The data source actor fetches data from the data source, processes it into device tensors,
// and then sends them to the kernel actors. The processing flow is FetchData -> FillDataBuffer ->
// SendMemoryAllocReq -> OnMemoryAllocFinish -> SendMemoryFreeReq -> SendOutput.
class DataSourceActor : public DebugAwareActor {
 public:
  DataSourceActor(const std::string &name, KernelTransformType type, size_t buffer_capacity,
                  const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid)
      : DebugAwareActor(name, type, recorder_aid, memory_manager_aid, debug_aid), buffer_capacity_(buffer_capacity) {}
  virtual ~DataSourceActor() = default;

  void Init() override;

  // The entry of data processing.
  void FetchData(OpContext<DeviceTensor> *const context);

  // The memory related operation interfaces.
  void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override {}
  void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override {}
  // Copy data from the data source to the device tensor buffer of the actor after the memory allocation finishes.
  void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override {}

 protected:
  friend class GraphScheduler;

  // Construct the device tensors from the member nodes and fill the device tensor buffer during data fetching.
  virtual void FillDataBuffer() = 0;

  // Send the graph output result to the output actor.
  virtual void SendResult(OpContext<DeviceTensor> *const context) = 0;

  // Send recorder info to the recorder actor; only the device queue data source actor needs it.
  virtual void SendRecorderInfo(OpContext<DeviceTensor> *const context) {}

  // Send output to the downstream actors to trigger computing after data fetching finishes.
  void SendOutput(OpContext<DeviceTensor> *const context);

  // The buffers store the device tensors.
  std::queue<std::vector<DeviceTensor *>> buffers_;
  size_t buffer_capacity_;

  // The output_data_ corresponds to the output_data_arrows_ one by one.
  std::vector<OpDataUniquePtr<DeviceTensor>> output_data_;
};

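// A minimal sketch (not part of the runtime) of what a concrete data source actor provides on top
// of this base class. The class name FileDataSourceActor is hypothetical, and an existing
// KernelTransformType value is reused only so the sketch reads like valid code; the real
// specializations are the two classes declared below.
//
//   class FileDataSourceActor : public DataSourceActor {
//    public:
//     FileDataSourceActor(const std::string &name, size_t buffer_capacity, const AID &memory_manager_aid,
//                         const AID *debug_aid, const AID *recorder_aid)
//         : DataSourceActor(name, KernelTransformType::kHostDataSourceActor, buffer_capacity, memory_manager_aid,
//                           debug_aid, recorder_aid) {}
//
//    protected:
//     // Build one batch of device tensors and push it into buffers_, bounded by buffer_capacity_.
//     void FillDataBuffer() override { /* fill buffers_ with the device tensors of one batch */ }
//     // Forward the device tensors of the graph output to the output actor.
//     void SendResult(OpContext<DeviceTensor> *const context) override { /* send the front of buffers_ */ }
//   };
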
// The class represents that the data source is the device queue.
class DeviceQueueDataSourceActor : public DataSourceActor {
 public:
  DeviceQueueDataSourceActor(const std::string &name, size_t buffer_capacity, const DeviceContext *device_context,
                             const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid)
      : DataSourceActor(name, KernelTransformType::kDeviceDataSourceActor, buffer_capacity, memory_manager_aid,
                        debug_aid, recorder_aid) {
    (void)device_contexts_.emplace_back(device_context);
  }
  ~DeviceQueueDataSourceActor() override = default;

  void Init() override;

  void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
  void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
  void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;

  void SendDebugReq(OpContext<DeviceTensor> *const context) override;
  void OnDebugFinish(OpContext<DeviceTensor> *const context) override;

 protected:
  void FillDataBuffer() override;
  void SendResult(OpContext<DeviceTensor> *const context) override;
  void SendRecorderInfo(OpContext<DeviceTensor> *const context) override;

 private:
  friend class GraphScheduler;

  // The input data kernel (for example, GetNext) fetches data from the device queue.
  CNodePtr data_kernel_{nullptr};
  KernelInfo *kernel_info_{nullptr};

  // The kernel launch info is fetched from the device tensors.
  KernelLaunchInfo launch_info_;
};

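// A minimal construction sketch (illustration only): in practice the GraphScheduler creates and
// wires this actor, and the actor name string and the variables below (device_context,
// memory_manager_aid, debug_aid, recorder_aid) are assumptions standing in for the
// scheduler-provided objects.
//
//   auto actor = std::make_shared<DeviceQueueDataSourceActor>(
//     "DeviceQueueDSActor_0", /*buffer_capacity=*/1, device_context, memory_manager_aid, debug_aid, recorder_aid);
//   // The scheduler later drives the flow by triggering actor->FetchData(context).
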
// The class represents that the data source is the host queue.
class HostQueueDataSourceActor : public DataSourceActor {
 public:
  HostQueueDataSourceActor(std::string name, size_t buffer_capacity, const AID memory_manager_aid, const AID *debug_aid,
                           const AID *recorder_aid, HostTensorQueuePtr host_queue)
      : DataSourceActor(name, KernelTransformType::kHostDataSourceActor, buffer_capacity, memory_manager_aid, debug_aid,
                        recorder_aid),
        host_queue_(host_queue) {}
  ~HostQueueDataSourceActor() override = default;

  void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
  void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
  void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;

  size_t FetchNodePosition(const AnfNodePtr &node) const override;
  AnfNodePtr FetchNode(size_t node_position) const;
  const std::vector<AnfNodePtr> &data_nodes() const { return data_nodes_; }

 protected:
  void FillDataBuffer() override;
  void SendResult(OpContext<DeviceTensor> *const context) override;

 private:
  friend class GraphScheduler;

  // Check whether all the data_nodes_ are from the same device type.
  bool IsSameDeviceType() const;

  HostTensorQueuePtr host_queue_;
  // The input data nodes fetch data from the host queue.
  std::vector<AnfNodePtr> data_nodes_;

  // The positions of the data nodes in the data source actor.
  std::unordered_map<AnfNodePtr, size_t> data_node_position_map_;
};

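// A minimal usage sketch (illustration only): the GraphScheduler performs the real setup, and
// host_queue, node, the AID variables, and the actor name string below are assumptions standing
// in for scheduler-provided objects.
//
//   auto host_actor = std::make_shared<HostQueueDataSourceActor>(
//     "HostQueueDSActor_0", /*buffer_capacity=*/1, memory_manager_aid, debug_aid, recorder_aid, host_queue);
//   size_t position = host_actor->FetchNodePosition(node);   // position of node among data_nodes()
//   AnfNodePtr data_node = host_actor->FetchNode(position);  // the reverse lookup
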
using DataSourceActorPtr = std::shared_ptr<DataSourceActor>;
using DeviceQueueDSActorPtr = std::shared_ptr<DeviceQueueDataSourceActor>;
using HostQueueDSActorPtr = std::shared_ptr<HostQueueDataSourceActor>;

}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_