• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_
19 
20 #include <vector>
21 #include <string>
22 #include <map>
23 #include <memory>
24 #include <queue>
25 #include <utility>
26 #include "utils/hash_map.h"
27 #include "runtime/graph_scheduler/actor/actor_common.h"
28 #include "runtime/graph_scheduler/actor/debug_aware_actor.h"
29 #include "runtime/hardware/device_context.h"
30 #include "runtime/graph_scheduler/device_tensor_store.h"
31 #include "runtime/graph_scheduler/host_tensor_queue.h"
32 #include "base/base.h"
33 
34 namespace mindspore {
35 namespace runtime {
36 using mindspore::device::DeviceContext;
37 using mindspore::device::KernelInfo;
38 using mindspore::kernel::KernelLaunchAddr;
39 
40 // The data source actor is used to fetch data from data source and process them into device tensors,
41 // and then send them to kernel actor. The processing flow is FetchData -> FillDataBuffer -> SendMemoryAllocReq
42 // -> OnMemoryAllocFinish -> SendMemoryFreeReq -> SendOutput.
43 class DataSourceActor : public DebugAwareActor {
44  public:
DataSourceActor(const std::string & name,KernelTransformType type,size_t buffer_capacity,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid)45   DataSourceActor(const std::string &name, KernelTransformType type, size_t buffer_capacity,
46                   const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid)
47       : DebugAwareActor(name, type, recorder_aid, memory_manager_aid, debug_aid, nullptr),
48         buffer_capacity_(buffer_capacity) {}
49   ~DataSourceActor() override = default;
50 
ReleaseData()51   virtual void ReleaseData() {}
52 
53  protected:
54   friend class GraphScheduler;
55   friend class ControlNodeScheduler;
56   friend class AnyTypeGraphScheduler;
57 
58   void Init() override;
59 
Run(OpContext<DeviceTensor> * const context)60   void Run(OpContext<DeviceTensor> *const context) override { FetchData(context); }
61 
62   // The process entry of data processing.
63   void FetchData(OpContext<DeviceTensor> *const context);
64 
65   // Construct the device tensors and fill to device tensor buffer from the member nodes during the data fetching.
66   virtual void FillDataBuffer() = 0;
67 
68   void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
69                         const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) override;
70 
71   // The buffers store the device tensors.
72   std::queue<std::vector<DeviceTensor *>> buffers_;
73   size_t buffer_capacity_;
74 };
75 
76 // The class represents that the data source is device queue.
77 class DeviceQueueDataSourceActor : public DataSourceActor {
78  public:
DeviceQueueDataSourceActor(const std::string & name,size_t buffer_capacity,const DeviceContext * device_context,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid)79   DeviceQueueDataSourceActor(const std::string &name, size_t buffer_capacity, const DeviceContext *device_context,
80                              const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid)
81       : DataSourceActor(name, KernelTransformType::kDeviceDataSourceActor, buffer_capacity, memory_manager_aid,
82                         debug_aid, recorder_aid) {
83     (void)device_contexts_.emplace_back(device_context);
84   }
85   ~DeviceQueueDataSourceActor() override = default;
86 
87   // The memory related operation interface.
88   void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
89   void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
90   // Copy data from data source to the device tensor buffer of actor after memory alloc finished.
91   void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
92 
93   void SendDebugReq(OpContext<DeviceTensor> *const context) override;
94 
data_kernel()95   const CNodePtr &data_kernel() const { return data_kernel_; }
96 
97  protected:
98   void Init() override;
99   void FillDataBuffer() override;
100   void SendRecorderInfo(OpContext<DeviceTensor> *const context) const override;
101 
102  private:
103   friend class GraphScheduler;
104   friend class ControlNodeScheduler;
105 
106   // Input data kernel(for example GetNext) fetches data from device queue.
107   CNodePtr data_kernel_{nullptr};
108   KernelInfo *kernel_info_{nullptr};
109 
110   bool is_dynamic_shape_;
111 
112   // The kernel mem info is needed for recording info.
113   KernelLaunchAddr mem_info_;
114 
115   // The kernel tensors for resize and launch.
116   std::vector<KernelTensor *> output_kernel_tensors_;
117 
118   // The stream resource of the Actor to launch kernel.
119   void *stream_{nullptr};
120 };
121 
122 // The class represents that the data source is host queue.
123 class HostQueueDataSourceActor : public DataSourceActor {
124  public:
HostQueueDataSourceActor(const std::string & name,size_t buffer_capacity,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid,const HostTensorQueuePtr & host_queue)125   HostQueueDataSourceActor(const std::string &name, size_t buffer_capacity, const AID &memory_manager_aid,
126                            const AID *debug_aid, const AID *recorder_aid, const HostTensorQueuePtr &host_queue)
127       : DataSourceActor(name, KernelTransformType::kHostDataSourceActor, buffer_capacity, memory_manager_aid, debug_aid,
128                         recorder_aid),
129         host_queue_(host_queue) {}
130   ~HostQueueDataSourceActor() override = default;
131 
132   // The memory related operation interface.
133   void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
134   void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
135   // Copy data from data source to the device tensor buffer of actor after memory alloc finished.
136   void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
137 
138   size_t FetchNodePosition(const KernelWithIndex &node) const override;
139   KernelWithIndex FetchNode(size_t node_position) const;
data_nodes()140   const std::vector<KernelWithIndex> &data_nodes() const { return data_node_with_indexs_; }
141 
142   void ReleaseData() override;
143 
144  protected:
145   void FillDataBuffer() override;
146 
147  private:
148   friend class GraphScheduler;
149   friend class ControlNodeScheduler;
150 
151   // Judge all the data_nodes_ is from the same device.
152   bool IsSameDeviceType() const;
153 
154   HostTensorQueuePtr host_queue_;
155   // Input data nodes fetch data from host queue.
156   std::vector<KernelWithIndex> data_node_with_indexs_;
157 
158   // The location of the data node in the data source actor.
159   std::map<KernelWithIndex, size_t> data_node_position_map_;
160 };
161 
162 using DataSourceActorPtr = std::shared_ptr<DataSourceActor>;
163 using DeviceQueueDSActorPtr = std::shared_ptr<DeviceQueueDataSourceActor>;
164 using HostQueueDSActorPtr = std::shared_ptr<HostQueueDataSourceActor>;
165 
166 }  // namespace runtime
167 }  // namespace mindspore
168 
169 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_
170