• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_
18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_
19 
20 #include <stdint.h>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 namespace mindspore::ge::model_runner {
/// Tag identifying the concrete kind of a task description (see TaskInfo::type()).
///
/// Every enumerator carries an explicit value so the numeric encoding is visible
/// without counting; the values are identical to the implicit sequence 0..14.
enum TaskInfoType {
  CCE = 0,
  TBE = 1,
  AICPU = 2,
  LABEL_SET = 3,
  LABEL_SWITCH = 4,
  LABEL_GOTO = 5,
  EVENT_RECORD = 6,
  EVENT_WAIT = 7,
  FUSION_START = 8,
  FUSION_END = 9,
  HCCL = 10,
  PROFILER_TRACE = 11,
  MEMCPY_ASYNC = 12,
  STREAM_SWITCH = 13,
  STREAM_ACTIVE = 14,
  // Insert new task type here, giving it an explicit value (next free value: 15).
  // NOTE(review): "REVSERVED" is a misspelling of "RESERVED"; the name is kept because
  // other code may already refer to it.
  REVSERVED = 23
};
46 
47 class TaskInfo {
48  public:
~TaskInfo()49   virtual ~TaskInfo() {}
stream_id()50   uint32_t stream_id() const { return stream_id_; }
type()51   TaskInfoType type() const { return type_; }
op_name()52   std::string op_name() const { return op_name_; }
dump_flag()53   bool dump_flag() const { return dump_flag_; }
54 
55  protected:
TaskInfo(const std::string & op_name,uint32_t stream_id,TaskInfoType type,bool dump_flag)56   TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
57       : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}
58 
59  private:
60   std::string op_name_;
61   uint32_t stream_id_;
62   TaskInfoType type_;
63   bool dump_flag_;
64 };
65 
66 class TbeTaskInfo : public TaskInfo {
67  public:
TbeTaskInfo(const std::string & op_name,uint32_t stream_id,const std::string & stub_func,uint32_t block_dim,const std::vector<uint8_t> & args,uint32_t args_size,const std::vector<uint8_t> & sm_desc,void * binary,uint32_t binary_size,const std::vector<uint8_t> & meta_data,const std::vector<void * > & input_data_addrs,const std::vector<void * > & output_data_addrs,const std::vector<void * > & workspace_addrs,bool dump_flag)68   TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
69               const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
70               uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
71               const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
72       : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
73         stub_func_(stub_func),
74         block_dim_(block_dim),
75         args_(args),
76         args_size_(args_size),
77         sm_desc_(sm_desc),
78         binary_(binary),
79         binary_size_(binary_size),
80         meta_data_(meta_data),
81         input_data_addrs_(input_data_addrs),
82         output_data_addrs_(output_data_addrs),
83         workspace_addrs_(workspace_addrs) {}
~TbeTaskInfo()84   ~TbeTaskInfo() override {}
85 
stub_func()86   const std::string &stub_func() const { return stub_func_; }
block_dim()87   uint32_t block_dim() const { return block_dim_; }
args()88   const std::vector<uint8_t> &args() const { return args_; }
args_size()89   uint32_t args_size() const { return args_size_; }
sm_desc()90   const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
binary()91   void *binary() const { return binary_; }
binary_size()92   uint32_t binary_size() const { return binary_size_; }
meta_data()93   const std::vector<uint8_t> &meta_data() const { return meta_data_; }
input_data_addrs()94   const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
output_data_addrs()95   const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
workspace_addrs()96   const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }
97 
SetBinary(void * binary,uint32_t binary_size)98   void SetBinary(void *binary, uint32_t binary_size) {
99     binary_ = binary;
100     binary_size_ = binary_size;
101   }
102 
103  private:
104   std::string stub_func_;
105   uint32_t block_dim_;
106   std::vector<uint8_t> args_;
107   uint32_t args_size_;
108   std::vector<uint8_t> sm_desc_;
109   void *binary_;
110   uint32_t binary_size_;
111   std::vector<uint8_t> meta_data_;
112   std::vector<void *> input_data_addrs_;
113   std::vector<void *> output_data_addrs_;
114   std::vector<void *> workspace_addrs_;
115 };
116 
117 class AicpuTaskInfo : public TaskInfo {
118  public:
AicpuTaskInfo(const std::string & op_name,uint32_t stream_id,const std::string & so_name,const std::string & kernel_name,const std::string & node_def,const std::string & ext_info,const std::vector<void * > & input_data_addrs,const std::vector<void * > & output_data_addrs,bool dump_flag)119   AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &so_name,
120                 const std::string &kernel_name, const std::string &node_def, const std::string &ext_info,
121                 const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs,
122                 bool dump_flag)
123       : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
124         so_name_(so_name),
125         kernel_name_(kernel_name),
126         node_def_(node_def),
127         ext_info_(ext_info),
128         input_data_addrs_(input_data_addrs),
129         output_data_addrs_(output_data_addrs) {}
~AicpuTaskInfo()130   ~AicpuTaskInfo() override {}
131 
so_name()132   const std::string &so_name() const { return so_name_; }
kernel_name()133   const std::string &kernel_name() const { return kernel_name_; }
node_def()134   const std::string &node_def() const { return node_def_; }
input_data_addrs()135   const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
output_data_addrs()136   const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
ext_info()137   const std::string &ext_info() const { return ext_info_; }
138 
139  private:
140   std::string so_name_;
141   std::string kernel_name_;
142   std::string node_def_;
143   std::string ext_info_;
144   std::vector<void *> input_data_addrs_;
145   std::vector<void *> output_data_addrs_;
146 };
147 
148 class LabelSetTaskInfo : public TaskInfo {
149  public:
LabelSetTaskInfo(const std::string & op_name,uint32_t stream_id,uint32_t label_id)150   LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
151       : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
~LabelSetTaskInfo()152   ~LabelSetTaskInfo() override {}
label_id()153   uint32_t label_id() const { return label_id_; }
154 
155  private:
156   uint32_t label_id_;
157 };
158 
159 class LabelGotoTaskInfo : public TaskInfo {
160  public:
LabelGotoTaskInfo(const std::string & op_name,uint32_t stream_id,uint32_t label_id)161   LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
162       : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
~LabelGotoTaskInfo()163   ~LabelGotoTaskInfo() override {}
label_id()164   uint32_t label_id() const { return label_id_; }
165 
166  private:
167   uint32_t label_id_;
168 };
169 
170 class LabelSwitchTaskInfo : public TaskInfo {
171  public:
LabelSwitchTaskInfo(const std::string & op_name,uint32_t stream_id,uint32_t label_size,const std::vector<uint32_t> & label_list,void * cond)172   LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
173                       const std::vector<uint32_t> &label_list, void *cond)
174       : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
175         label_size_(label_size),
176         label_list_(label_list),
177         cond_(cond) {}
~LabelSwitchTaskInfo()178   ~LabelSwitchTaskInfo() override {}
label_size()179   uint32_t label_size() const { return label_size_; }
label_list()180   const std::vector<uint32_t> &label_list() const { return label_list_; }
cond()181   void *cond() const { return cond_; }
182 
183  private:
184   uint32_t label_size_;
185   std::vector<uint32_t> label_list_;
186   void *cond_;
187 };
188 
189 class EventTaskInfo : public TaskInfo {
190  public:
event_id()191   uint32_t event_id() const { return event_id_; }
192 
193  protected:
EventTaskInfo(const std::string & op_name,uint32_t stream_id,TaskInfoType type,uint32_t event_id)194   EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
195       : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
~EventTaskInfo()196   ~EventTaskInfo() override {}
197 
198   uint32_t event_id_;
199 };
200 
201 class EventRecordTaskInfo : public EventTaskInfo {
202  public:
EventRecordTaskInfo(const std::string & op_name,uint32_t stream_id,uint32_t event_id)203   EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
204       : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
~EventRecordTaskInfo()205   ~EventRecordTaskInfo() override {}
206 };
207 
208 class EventWaitTaskInfo : public EventTaskInfo {
209  public:
EventWaitTaskInfo(const std::string & op_name,uint32_t stream_id,uint32_t event_id)210   EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
211       : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
~EventWaitTaskInfo()212   ~EventWaitTaskInfo() override {}
213 };
214 
215 class FusionStartTaskInfo : public TaskInfo {
216  public:
FusionStartTaskInfo(const std::string & op_name,uint32_t stream_id)217   explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
218       : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
~FusionStartTaskInfo()219   ~FusionStartTaskInfo() override {}
220 };
221 
222 class FusionEndTaskInfo : public TaskInfo {
223  public:
FusionEndTaskInfo(const std::string & op_name,uint32_t stream_id)224   explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
225       : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
~FusionEndTaskInfo()226   ~FusionEndTaskInfo() override {}
227 };
228 
229 class HcclTaskInfo : public TaskInfo {
230  public:
HcclTaskInfo(const std::string & op_name,uint32_t stream_id,const std::string hccl_type,void * input_data_addr,void * output_data_addr,void * workspace_addr,int64_t workspace_size,int64_t hccl_stream_num,const std::vector<uint8_t> & private_def,void * ops_kernel_store,int32_t count,int64_t root_id,int64_t op_type,int64_t data_type,const std::string & group,bool dump_flag)231   HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
232                void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
233                const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
234                int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
235       : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
236         hccl_type_(hccl_type),
237         input_data_addr_(input_data_addr),
238         output_data_addr_(output_data_addr),
239         workspace_addr_(workspace_addr),
240         workspace_size_(workspace_size),
241         hccl_stream_num_(hccl_stream_num),
242         private_def_(private_def),
243         ops_kernel_store_(ops_kernel_store),
244         count_(count),
245         root_id_(root_id),
246         op_type_(op_type),
247         data_type_(data_type),
248         group_(group) {}
~HcclTaskInfo()249   ~HcclTaskInfo() override {}
250 
hccl_type()251   const std::string &hccl_type() const { return hccl_type_; }
input_data_addr()252   void *input_data_addr() const { return input_data_addr_; }
output_data_addr()253   void *output_data_addr() const { return output_data_addr_; }
workspace_addr()254   void *workspace_addr() const { return workspace_addr_; }
workspace_size()255   int64_t workspace_size() const { return workspace_size_; }
hccl_stream_num()256   int64_t hccl_stream_num() const { return hccl_stream_num_; }
private_def()257   const std::vector<uint8_t> &private_def() const { return private_def_; }
ops_kernel_store()258   void *ops_kernel_store() const { return ops_kernel_store_; }
count()259   int32_t count() const { return count_; }
root_id()260   int64_t root_id() const { return root_id_; }
op_type()261   int64_t op_type() const { return op_type_; }
data_type()262   int64_t data_type() const { return data_type_; }
group()263   const std::string &group() const { return group_; }
264 
265  private:
266   std::string hccl_type_;
267   void *input_data_addr_;
268   void *output_data_addr_;
269   void *workspace_addr_;
270   int64_t workspace_size_;
271   int64_t hccl_stream_num_;
272   std::vector<uint8_t> private_def_;
273   void *ops_kernel_store_;
274   int32_t count_;
275   int64_t root_id_;
276   int64_t op_type_;
277   int64_t data_type_;
278   std::string group_;
279 };
280 
281 class ProfilerTraceTaskInfo : public TaskInfo {
282  public:
ProfilerTraceTaskInfo(const std::string & op_name,uint32_t stream_id,uint64_t log_id,bool notify,uint32_t flat)283   ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
284       : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
285         log_id_(log_id),
286         notify_(notify),
287         flat_(flat) {}
~ProfilerTraceTaskInfo()288   ~ProfilerTraceTaskInfo() override {}
289 
log_id()290   uint64_t log_id() const { return log_id_; }
notify()291   bool notify() const { return notify_; }
flat()292   uint32_t flat() const { return flat_; }
293 
294  private:
295   uint64_t log_id_;
296   bool notify_;
297   uint32_t flat_;
298 };
299 
300 class MemcpyAsyncTaskInfo : public TaskInfo {
301  public:
MemcpyAsyncTaskInfo(const std::string & op_name,uint32_t stream_id,void * dst,uint64_t dst_max,void * src,uint64_t count,uint32_t kind,bool dump_flag)302   MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
303                       uint64_t count, uint32_t kind, bool dump_flag)
304       : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
305         dst_(dst),
306         dst_max_(dst_max),
307         src_(src),
308         count_(count),
309         kind_(kind) {}
~MemcpyAsyncTaskInfo()310   ~MemcpyAsyncTaskInfo() override {}
311 
dst()312   void *dst() const { return dst_; }
dst_max()313   uint64_t dst_max() const { return dst_max_; }
src()314   void *src() const { return src_; }
count()315   uint64_t count() const { return count_; }
kind()316   uint32_t kind() const { return kind_; }
317 
318  private:
319   void *dst_;
320   uint64_t dst_max_;
321   void *src_;
322   uint64_t count_;
323   int32_t kind_;
324 };
325 
326 class StreamSwitchTaskInfo : public TaskInfo {
327  public:
StreamSwitchTaskInfo(const std::string & op_name,uint32_t stream_id,int64_t true_stream_id,void * input_addr,void * value_addr,int64_t cond,int64_t data_type)328   StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
329                        void *value_addr, int64_t cond, int64_t data_type)
330       : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
331         true_stream_id_(true_stream_id),
332         input_addr_(input_addr),
333         value_addr_(value_addr),
334         cond_(cond),
335         data_type_(data_type) {}
~StreamSwitchTaskInfo()336   ~StreamSwitchTaskInfo() override {}
337 
true_stream_id()338   int64_t true_stream_id() const { return true_stream_id_; }
input_addr()339   void *input_addr() const { return input_addr_; }
value_addr()340   void *value_addr() const { return value_addr_; }
cond()341   int64_t cond() const { return cond_; }
data_type()342   int64_t data_type() const { return data_type_; }
343 
344  private:
345   int64_t true_stream_id_;
346   void *input_addr_;
347   void *value_addr_;
348   int64_t cond_;
349   int64_t data_type_;
350 };
351 
352 class StreamActiveTaskInfo : public TaskInfo {
353  public:
StreamActiveTaskInfo(const std::string & op_name,uint32_t stream_id,uint32_t active_stream_id)354   StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
355       : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
~StreamActiveTaskInfo()356   ~StreamActiveTaskInfo() override {}
357 
active_stream_id()358   uint32_t active_stream_id() const { return active_stream_id_; }
359 
360  private:
361   uint32_t active_stream_id_;
362 };
363 }  // namespace mindspore::ge::model_runner
364 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_
365