/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_

#include <stdint.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace mindspore::ge::model_runner {
/// @brief Discriminator returned by TaskInfo::type(); each concrete task-info class
///        passes a fixed value of this enum to the TaskInfo base constructor.
///
/// Enumerators after CCE are numbered implicitly (TBE = 1 ... STREAM_ACTIVE = 14).
/// REVSERVED is pinned to 23, so new enumerators appended after STREAM_ACTIVE can
/// take values 15..22 without colliding with it.
enum TaskInfoType {
  CCE = 0,
  TBE,
  AICPU,
  LABEL_SET,
  LABEL_SWITCH,
  LABEL_GOTO,
  EVENT_RECORD,
  EVENT_WAIT,
  FUSION_START,
  FUSION_END,
  HCCL,
  PROFILER_TRACE,
  MEMCPY_ASYNC,
  STREAM_SWITCH,
  STREAM_ACTIVE,
  // Insert new task type here
  // NOTE: the spelling "REVSERVED" is intentionally kept; renaming it would break
  // any caller that already references this enumerator.
  REVSERVED = 23
};

/// @brief Common base of all task descriptions: op name, stream id, task type and dump flag.
///
/// The constructor is protected, so only derived task-info classes can create instances.
/// The destructor is virtual, so deleting a derived object through a TaskInfo pointer is safe.
class TaskInfo {
 public:
  virtual ~TaskInfo() = default;

  /// @return Stream id passed at construction.
  uint32_t stream_id() const { return stream_id_; }
  /// @return Task kind fixed by the concrete subclass.
  TaskInfoType type() const { return type_; }
  /// @return A copy of the op name (returned by value).
  std::string op_name() const { return op_name_; }
  /// @return Dump flag passed at construction (some subclasses always pass false).
  bool dump_flag() const { return dump_flag_; }

 protected:
  /// @param op_name   Name of the operator this task was built for.
  /// @param stream_id Stream id the task is associated with.
  /// @param type      Concrete task kind.
  /// @param dump_flag Flag exposed via dump_flag().
  TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
      : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}

 private:
  std::string op_name_;
  uint32_t stream_id_;
  TaskInfoType type_;
  bool dump_flag_;
};

/// @brief Task description of type TBE.
///
/// All arguments are stored verbatim. Raw pointers (binary, *_addrs) are opaque here:
/// this class never dereferences or frees them, so their lifetime is managed elsewhere
/// (presumably by the caller — TODO confirm). args_size is stored independently of
/// args.size(); the two are not checked for consistency.
class TbeTaskInfo : public TaskInfo {
 public:
  TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
              const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
              uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
              const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
        stub_func_(stub_func),
        block_dim_(block_dim),
        args_(args),
        args_size_(args_size),
        sm_desc_(sm_desc),
        binary_(binary),
        binary_size_(binary_size),
        meta_data_(meta_data),
        input_data_addrs_(input_data_addrs),
        output_data_addrs_(output_data_addrs),
        workspace_addrs_(workspace_addrs) {}
  ~TbeTaskInfo() override = default;

  const std::string &stub_func() const { return stub_func_; }
  uint32_t block_dim() const { return block_dim_; }
  const std::vector<uint8_t> &args() const { return args_; }
  uint32_t args_size() const { return args_size_; }
  const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  void *binary() const { return binary_; }
  uint32_t binary_size() const { return binary_size_; }
  const std::vector<uint8_t> &meta_data() const { return meta_data_; }
  const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }

  /// @brief Replaces the stored binary pointer and size together.
  ///        The previously stored pointer is simply overwritten (not freed).
  void SetBinary(void *binary, uint32_t binary_size) {
    binary_ = binary;
    binary_size_ = binary_size;
  }

 private:
  std::string stub_func_;
  uint32_t block_dim_;
  std::vector<uint8_t> args_;
  uint32_t args_size_;
  std::vector<uint8_t> sm_desc_;
  void *binary_;
  uint32_t binary_size_;
  std::vector<uint8_t> meta_data_;
  std::vector<void *> input_data_addrs_;
  std::vector<void *> output_data_addrs_;
  std::vector<void *> workspace_addrs_;
};

/// @brief Task description of type AICPU.
///
/// Strings are copied; address vectors hold opaque pointers that are never
/// dereferenced or freed by this class.
class AicpuTaskInfo : public TaskInfo {
 public:
  AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &so_name,
                const std::string &kernel_name, const std::string &node_def, const std::string &ext_info,
                const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs,
                bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
        so_name_(so_name),
        kernel_name_(kernel_name),
        node_def_(node_def),
        ext_info_(ext_info),
        input_data_addrs_(input_data_addrs),
        output_data_addrs_(output_data_addrs) {}
  ~AicpuTaskInfo() override = default;

  const std::string &so_name() const { return so_name_; }
  const std::string &kernel_name() const { return kernel_name_; }
  const std::string &node_def() const { return node_def_; }
  const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  const std::string &ext_info() const { return ext_info_; }

 private:
  std::string so_name_;
  std::string kernel_name_;
  std::string node_def_;
  std::string ext_info_;
  std::vector<void *> input_data_addrs_;
  std::vector<void *> output_data_addrs_;
};

/// @brief Task description of type LABEL_SET carrying a single label id (dump flag is always false).
class LabelSetTaskInfo : public TaskInfo {
 public:
  LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
  ~LabelSetTaskInfo() override = default;
  uint32_t label_id() const { return label_id_; }

 private:
  uint32_t label_id_;
};

/// @brief Task description of type LABEL_GOTO carrying a single label id (dump flag is always false).
class LabelGotoTaskInfo : public TaskInfo {
 public:
  LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
  ~LabelGotoTaskInfo() override = default;
  uint32_t label_id() const { return label_id_; }

 private:
  uint32_t label_id_;
};

/// @brief Task description of type LABEL_SWITCH (dump flag is always false).
///
/// label_size is stored independently of label_list.size(); the two are not
/// checked for consistency here. cond is an opaque pointer, never dereferenced
/// or freed by this class.
class LabelSwitchTaskInfo : public TaskInfo {
 public:
  LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
                      const std::vector<uint32_t> &label_list, void *cond)
      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
        label_size_(label_size),
        label_list_(label_list),
        cond_(cond) {}
  ~LabelSwitchTaskInfo() override = default;
  uint32_t label_size() const { return label_size_; }
  const std::vector<uint32_t> &label_list() const { return label_list_; }
  void *cond() const { return cond_; }

 private:
  uint32_t label_size_;
  std::vector<uint32_t> label_list_;
  void *cond_;
};

/// @brief Shared base for event tasks, holding an event id (dump flag is always false).
///
/// Both the constructor and the destructor are protected: instances are created
/// through EventRecordTaskInfo / EventWaitTaskInfo. They can be destroyed through a
/// TaskInfo pointer (virtual public destructor there), but not through an
/// EventTaskInfo pointer.
class EventTaskInfo : public TaskInfo {
 public:
  uint32_t event_id() const { return event_id_; }

 protected:
  EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
      : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
  ~EventTaskInfo() override = default;

  uint32_t event_id_;
};

/// @brief Event task of type EVENT_RECORD.
class EventRecordTaskInfo : public EventTaskInfo {
 public:
  EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
      : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
  ~EventRecordTaskInfo() override = default;
};

/// @brief Event task of type EVENT_WAIT.
class EventWaitTaskInfo : public EventTaskInfo {
 public:
  EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
      : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
  ~EventWaitTaskInfo() override = default;
};

/// @brief Marker task of type FUSION_START with no payload (dump flag is always false).
class FusionStartTaskInfo : public TaskInfo {
 public:
  explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
  ~FusionStartTaskInfo() override = default;
};

/// @brief Marker task of type FUSION_END with no payload (dump flag is always false).
class FusionEndTaskInfo : public TaskInfo {
 public:
  explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
  ~FusionEndTaskInfo() override = default;
};

/// @brief Task description of type HCCL.
///
/// All pointer members (data/workspace addresses, ops_kernel_store) are opaque:
/// this class never dereferences or frees them.
class HcclTaskInfo : public TaskInfo {
 public:
  /// @param hccl_type Taken by const reference (previously by const value, which forced
  ///                  an extra string copy); callers are unaffected.
  HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &hccl_type, void *input_data_addr,
               void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
               const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
               int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
        hccl_type_(hccl_type),
        input_data_addr_(input_data_addr),
        output_data_addr_(output_data_addr),
        workspace_addr_(workspace_addr),
        workspace_size_(workspace_size),
        hccl_stream_num_(hccl_stream_num),
        private_def_(private_def),
        ops_kernel_store_(ops_kernel_store),
        count_(count),
        root_id_(root_id),
        op_type_(op_type),
        data_type_(data_type),
        group_(group) {}
  ~HcclTaskInfo() override = default;

  const std::string &hccl_type() const { return hccl_type_; }
  void *input_data_addr() const { return input_data_addr_; }
  void *output_data_addr() const { return output_data_addr_; }
  void *workspace_addr() const { return workspace_addr_; }
  int64_t workspace_size() const { return workspace_size_; }
  int64_t hccl_stream_num() const { return hccl_stream_num_; }
  const std::vector<uint8_t> &private_def() const { return private_def_; }
  void *ops_kernel_store() const { return ops_kernel_store_; }
  int32_t count() const { return count_; }
  int64_t root_id() const { return root_id_; }
  int64_t op_type() const { return op_type_; }
  int64_t data_type() const { return data_type_; }
  const std::string &group() const { return group_; }

 private:
  std::string hccl_type_;
  void *input_data_addr_;
  void *output_data_addr_;
  void *workspace_addr_;
  int64_t workspace_size_;
  int64_t hccl_stream_num_;
  std::vector<uint8_t> private_def_;
  void *ops_kernel_store_;
  int32_t count_;
  int64_t root_id_;
  int64_t op_type_;
  int64_t data_type_;
  std::string group_;
};

/// @brief Task description of type PROFILER_TRACE (dump flag is always false).
class ProfilerTraceTaskInfo : public TaskInfo {
 public:
  ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
      : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
        log_id_(log_id),
        notify_(notify),
        flat_(flat) {}
  ~ProfilerTraceTaskInfo() override = default;

  uint64_t log_id() const { return log_id_; }
  bool notify() const { return notify_; }
  uint32_t flat() const { return flat_; }

 private:
  uint64_t log_id_;
  bool notify_;
  uint32_t flat_;
};

/// @brief Task description of type MEMCPY_ASYNC.
///
/// dst/src are opaque pointers stored verbatim; this class never dereferences
/// or frees them.
class MemcpyAsyncTaskInfo : public TaskInfo {
 public:
  MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
                      uint64_t count, uint32_t kind, bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
        dst_(dst),
        dst_max_(dst_max),
        src_(src),
        count_(count),
        kind_(kind) {}
  ~MemcpyAsyncTaskInfo() override = default;

  void *dst() const { return dst_; }
  uint64_t dst_max() const { return dst_max_; }
  void *src() const { return src_; }
  uint64_t count() const { return count_; }
  uint32_t kind() const { return kind_; }

 private:
  void *dst_;
  uint64_t dst_max_;
  void *src_;
  uint64_t count_;
  // Stored as uint32_t to match the constructor parameter and the kind() getter;
  // it was previously int32_t, which forced a signed/unsigned round-trip.
  uint32_t kind_;
};

/// @brief Task description of type STREAM_SWITCH (dump flag is always false).
///
/// input_addr/value_addr are opaque pointers, never dereferenced or freed here.
class StreamSwitchTaskInfo : public TaskInfo {
 public:
  StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
                       void *value_addr, int64_t cond, int64_t data_type)
      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
        true_stream_id_(true_stream_id),
        input_addr_(input_addr),
        value_addr_(value_addr),
        cond_(cond),
        data_type_(data_type) {}
  ~StreamSwitchTaskInfo() override = default;

  int64_t true_stream_id() const { return true_stream_id_; }
  void *input_addr() const { return input_addr_; }
  void *value_addr() const { return value_addr_; }
  int64_t cond() const { return cond_; }
  int64_t data_type() const { return data_type_; }

 private:
  int64_t true_stream_id_;
  void *input_addr_;
  void *value_addr_;
  int64_t cond_;
  int64_t data_type_;
};

/// @brief Task description of type STREAM_ACTIVE carrying the id of the stream to
///        activate (dump flag is always false).
class StreamActiveTaskInfo : public TaskInfo {
 public:
  StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
  ~StreamActiveTaskInfo() override = default;

  uint32_t active_stream_id() const { return active_stream_id_; }

 private:
  uint32_t active_stream_id_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_