1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <memory> 17 #include "common/common_test.h" 18 #define private public 19 #include "runtime/device/ascend/ge_runtime/model_runner.h" 20 #include "runtime/device/ascend/ge_runtime/runtime_model.h" 21 #include "runtime/device/ascend/ge_runtime/task/task_factory.h" 22 #include "runtime/device/ascend/ge_runtime/task/aicpu_task.h" 23 #include "runtime/device/ascend/ge_runtime/task/event_record_task.h" 24 #include "runtime/device/ascend/ge_runtime/task/event_wait_task.h" 25 #include "runtime/device/ascend/ge_runtime/task/hccl_task.h" 26 #include "runtime/device/ascend/ge_runtime/task/label_goto_task.h" 27 #include "runtime/device/ascend/ge_runtime/task/label_manager.h" 28 #include "runtime/device/ascend/ge_runtime/task/label_set_task.h" 29 #include "runtime/device/ascend/ge_runtime/task/label_switch_task.h" 30 #include "runtime/device/ascend/ge_runtime/task/memcpy_async_task.h" 31 #include "runtime/device/ascend/ge_runtime/task/profiler_task.h" 32 #include "runtime/device/ascend/ge_runtime/task/stream_active_task.h" 33 #include "runtime/device/ascend/ge_runtime/task/stream_switch_task.h" 34 #include "runtime/device/ascend/ge_runtime/task/tbe_task.h" 35 #undef private 36 #include "common/opskernel/ops_kernel_info_store.h" 37 38 using namespace mindspore::ge::model_runner; 39 using namespace testing; 40 41 class MockOpsKernelInfoStore : public ge::OpsKernelInfoStore { 42 public: 43 ge::Status Initialize(const map<string, string> &) override { return ge::SUCCESS; } 44 ge::Status Finalize() override { return ge::SUCCESS; } 45 void GetAllOpsKernelInfo(std::map<string, ge::OpInfo> &infos) const override {} 46 bool CheckSupported(const ge::OpDescPtr &opDescPtr, std::string &un_supported_reason) const override { return true; } 47 ge::Status LoadTask(ge::GETaskInfo &task) override { return ge::SUCCESS; } 48 }; 49 50 namespace mindspore { 51 class TestAscendGeRuntime : public UT::Common { 52 public: 53 TestAscendGeRuntime() {} 54 55 private: 56 void TearDown() override { 57 { 58 std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_); 59 HcclTask::model_stream_mapping_.clear(); 60 } 61 } 62 }; 63 64 TEST_F(TestAscendGeRuntime, test_task_create_null_task_info_failed) { 65 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 66 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)}, 67 {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 68 ASSERT_TRUE(TaskFactory::GetInstance().Create(model_context, nullptr) == nullptr); 69 } 70 71 TEST_F(TestAscendGeRuntime, test_aicpu_task_create_one_stream_success) { 72 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 73 {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, 74 {reinterpret_cast<rtEvent_t>(1)}); 75 std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>( 76 "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)}, 77 std::vector<void *>{reinterpret_cast<void *>(1)}, true); 78 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, aicpu_task_info); 79 ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(task) != nullptr); 80 ASSERT_NO_THROW(task->Distribute()); 81 } 82 83 TEST_F(TestAscendGeRuntime, test_aicpu_task_create_multi_stream_success) { 84 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 85 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)}, 86 {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 87 std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>( 88 "op_name", 0, "so_name", "kernel_name", "node_def", "", std::vector<void *>{reinterpret_cast<void *>(1)}, 89 std::vector<void *>{reinterpret_cast<void *>(1)}, true); 90 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, aicpu_task_info); 91 ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(task) != nullptr); 92 ASSERT_NO_THROW(task->Distribute()); 93 } 94 95 TEST_F(TestAscendGeRuntime, test_aicpu_task_create_invalid_stream_id_failed) { 96 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 97 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)}, 98 {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 99 std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>( 100 "op_name", 5, "so_name", "kernel_name", "node_def", "", std::vector<void *>{reinterpret_cast<void *>(1)}, 101 std::vector<void *>{reinterpret_cast<void *>(1)}, true); 102 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, aicpu_task_info)); 103 } 104 105 TEST_F(TestAscendGeRuntime, test_event_record_task_create_success) { 106 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 107 {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, 108 {reinterpret_cast<rtEvent_t>(1)}); 109 std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 0); 110 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_record_task_info); 111 ASSERT_TRUE(std::dynamic_pointer_cast<EventRecordTask>(task) != nullptr); 112 ASSERT_NO_THROW(task->Distribute()); 113 } 114 115 TEST_F(TestAscendGeRuntime, test_event_record_task_create_invalid_event_id_failed) { 116 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 117 {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, 118 {reinterpret_cast<rtEvent_t>(1)}); 119 std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 10); 120 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_record_task_info)); 121 } 122 123 TEST_F(TestAscendGeRuntime, test_event_wait_task_create_success) { 124 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 125 {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, 126 {reinterpret_cast<rtEvent_t>(1)}); 127 std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 0); 128 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_record_task_info); 129 ASSERT_TRUE(std::dynamic_pointer_cast<EventWaitTask>(task) != nullptr); 130 ASSERT_NO_THROW(task->Distribute()); 131 } 132 133 TEST_F(TestAscendGeRuntime, test_event_wait_task_create_invalid_event_id_failed) { 134 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 135 {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, 136 {reinterpret_cast<rtEvent_t>(1)}); 137 std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 10); 138 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_record_task_info)); 139 } 140 141 TEST_F(TestAscendGeRuntime, test_hccl_task_create_success) { 142 MockOpsKernelInfoStore ops_kernel_info_store; 143 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 144 {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, 145 {reinterpret_cast<rtEvent_t>(1)}); 146 std::shared_ptr<TaskInfo> hccl_task_info = std::make_shared<HcclTaskInfo>( 147 "op_name", 0, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 148 5, std::vector<uint8_t>(6, 7), reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true); 149 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, hccl_task_info); 150 ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task) != nullptr); 151 ASSERT_NO_THROW(task->Distribute()); 152 } 153 154 TEST_F(TestAscendGeRuntime, test_hccl_task_create_stream_reuse_success) { 155 const rtModel_t model = reinterpret_cast<rtModel_t>(0x12345678); 156 const rtStream_t stream = reinterpret_cast<rtStream_t>(0x87654321); 157 constexpr uint32_t stream_id = 0; 158 constexpr int64_t task1_stream_num = 3; 159 constexpr int64_t task2_stream_num = 5; 160 constexpr int64_t task3_stream_num = 4; 161 MockOpsKernelInfoStore ops_kernel_info_store; 162 ModelContext model_context(0, 0, 0, model, reinterpret_cast<rtStream_t>(2), {stream}, 163 {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 164 std::shared_ptr<TaskInfo> hccl_task_info_1 = std::make_shared<HcclTaskInfo>( 165 "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), 166 reinterpret_cast<void *>(3), 4, task1_stream_num, std::vector<uint8_t>(6, 7), 167 reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true); 168 std::shared_ptr<TaskInfo> hccl_task_info_2 = std::make_shared<HcclTaskInfo>( 169 "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), 170 reinterpret_cast<void *>(3), 4, task2_stream_num, std::vector<uint8_t>(6, 7), 171 reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true); 172 std::shared_ptr<TaskInfo> hccl_task_info_3 = std::make_shared<HcclTaskInfo>( 173 "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), 174 reinterpret_cast<void *>(3), 4, task3_stream_num, std::vector<uint8_t>(6, 7), 175 reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true); 176 std::shared_ptr<Task> task_1 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_1); 177 std::shared_ptr<Task> task_2 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_2); 178 std::shared_ptr<Task> task_3 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_3); 179 ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_1) != nullptr); 180 ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_2) != nullptr); 181 ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_3) != nullptr); 182 ASSERT_NO_THROW(task_1->Distribute()); 183 ASSERT_NO_THROW(task_2->Distribute()); 184 ASSERT_NO_THROW(task_3->Distribute()); 185 { 186 std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_); 187 auto model_iter = HcclTask::model_stream_mapping_.find(model); 188 ASSERT_NE(model_iter, HcclTask::model_stream_mapping_.end()); 189 auto stream_iter = model_iter->second.find(stream_id); 190 ASSERT_NE(stream_iter, model_iter->second.end()); 191 const auto &stream_vec = stream_iter->second; 192 ASSERT_EQ(stream_vec.size(), std::max(task1_stream_num, std::max(task2_stream_num, task3_stream_num))); 193 for (const auto &s : stream_vec) { 194 auto shared = s.lock(); 195 ASSERT_TRUE(shared != nullptr); 196 } 197 } 198 } 199 200 TEST_F(TestAscendGeRuntime, test_label_goto_task_create_success) { 201 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 202 {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, 203 {reinterpret_cast<rtEvent_t>(1)}); 204 std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0); 205 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_goto_task_info); 206 auto label_goto_task = std::dynamic_pointer_cast<LabelGotoTask>(task); 207 ASSERT_TRUE(label_goto_task != nullptr); 208 ASSERT_NO_THROW(task->Distribute()); 209 label_goto_task->index_value_ = new uint8_t[5]; 210 } 211 212 TEST_F(TestAscendGeRuntime, test_label_goto_task_create_invalid_label_id_failed) { 213 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 214 {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, 215 {reinterpret_cast<rtEvent_t>(1)}); 216 std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 1); 217 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_goto_task_info)); 218 } 219 220 TEST_F(TestAscendGeRuntime, test_label_goto_task_reuse_success) { 221 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 222 {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, 223 {reinterpret_cast<rtEvent_t>(1)}); 224 std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0); 225 std::shared_ptr<Task> task1 = TaskFactory::GetInstance().Create(model_context, label_goto_task_info); 226 std::shared_ptr<Task> task2 = TaskFactory::GetInstance().Create(model_context, label_goto_task_info); 227 auto label_goto_task_1 = std::dynamic_pointer_cast<LabelGotoTask>(task1); 228 auto label_goto_task_2 = std::dynamic_pointer_cast<LabelGotoTask>(task2); 229 ASSERT_TRUE(label_goto_task_1 != nullptr); 230 ASSERT_NO_THROW(task1->Distribute()); 231 ASSERT_TRUE(label_goto_task_2 != nullptr); 232 ASSERT_NO_THROW(task2->Distribute()); 233 ASSERT_EQ(label_goto_task_1->label_info_, label_goto_task_2->label_info_); 234 } 235 236 TEST_F(TestAscendGeRuntime, test_label_set_task_create_success) { 237 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 238 {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, 239 {reinterpret_cast<rtEvent_t>(1)}); 240 std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 0); 241 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_set_task_info); 242 ASSERT_TRUE(std::dynamic_pointer_cast<LabelSetTask>(task) != nullptr); 243 ASSERT_NO_THROW(task->Distribute()); 244 } 245 246 TEST_F(TestAscendGeRuntime, test_label_set_task_create_invalid_label_id_failed) { 247 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 248 {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, 249 {reinterpret_cast<rtEvent_t>(1)}); 250 std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 1); 251 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_set_task_info)); 252 } 253 254 TEST_F(TestAscendGeRuntime, test_label_switch_task_create_success) { 255 ModelContext model_context( 256 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, 257 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 258 std::shared_ptr<TaskInfo> label_switch_task_info = 259 std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1)); 260 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_switch_task_info); 261 ASSERT_TRUE(std::dynamic_pointer_cast<LabelSwitchTask>(task) != nullptr); 262 ASSERT_NO_THROW(task->Distribute()); 263 } 264 265 TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_stream_id_failed) { 266 ModelContext model_context( 267 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, 268 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 269 std::shared_ptr<TaskInfo> label_switch_task_info = 270 std::make_shared<LabelSwitchTaskInfo>("op_name", 1, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1)); 271 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_switch_task_info)); 272 } 273 274 TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_label_id_failed) { 275 ModelContext model_context( 276 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, 277 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 278 std::shared_ptr<TaskInfo> label_switch_task_info = 279 std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 3, std::vector<uint32_t>{0, 1, 2}, reinterpret_cast<void *>(1)); 280 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_switch_task_info)); 281 } 282 283 TEST_F(TestAscendGeRuntime, test_label_switch_task_reuse_success) { 284 ModelContext model_context( 285 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, 286 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 287 std::shared_ptr<TaskInfo> label_switch_task_info = 288 std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1)); 289 std::shared_ptr<Task> task1 = TaskFactory::GetInstance().Create(model_context, label_switch_task_info); 290 std::shared_ptr<Task> task2 = TaskFactory::GetInstance().Create(model_context, label_switch_task_info); 291 auto label_switch_task_1 = std::dynamic_pointer_cast<LabelSwitchTask>(task1); 292 auto label_switch_task_2 = std::dynamic_pointer_cast<LabelSwitchTask>(task2); 293 ASSERT_TRUE(label_switch_task_1 != nullptr); 294 ASSERT_TRUE(label_switch_task_2 != nullptr); 295 ASSERT_NO_THROW(task1->Distribute()); 296 ASSERT_NO_THROW(task2->Distribute()); 297 ASSERT_EQ(label_switch_task_1->label_info_, label_switch_task_2->label_info_); 298 } 299 300 TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_success) { 301 ModelContext model_context( 302 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, 303 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 304 std::shared_ptr<TaskInfo> memcpy_task_info = std::make_shared<MemcpyAsyncTaskInfo>( 305 "op_name", 0, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true); 306 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, memcpy_task_info); 307 ASSERT_TRUE(std::dynamic_pointer_cast<MemcpyAsyncTask>(task) != nullptr); 308 ASSERT_NO_THROW(task->Distribute()); 309 } 310 311 TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_invalid_stream_id_failed) { 312 ModelContext model_context( 313 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, 314 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 315 std::shared_ptr<TaskInfo> memcpy_task_info = std::make_shared<MemcpyAsyncTaskInfo>( 316 "op_name", 1, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true); 317 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, memcpy_task_info)); 318 } 319 320 TEST_F(TestAscendGeRuntime, test_profiler_task_create_success) { 321 ModelContext model_context( 322 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, 323 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 324 std::shared_ptr<TaskInfo> profiler_task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 0, 1, true, 2); 325 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, profiler_task_info); 326 ASSERT_TRUE(std::dynamic_pointer_cast<ProfilerTask>(task) != nullptr); 327 ASSERT_NO_THROW(task->Distribute()); 328 } 329 330 TEST_F(TestAscendGeRuntime, test_profiler_task_create_invalid_stream_id_failed) { 331 ModelContext model_context( 332 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, 333 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); 334 std::shared_ptr<TaskInfo> profiler_task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 1, 1, true, 2); 335 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, profiler_task_info)); 336 } 337 338 TEST_F(TestAscendGeRuntime, test_stream_active_task_create_success) { 339 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 340 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, 341 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, 342 {reinterpret_cast<rtEvent_t>(1)}); 343 std::shared_ptr<TaskInfo> stream_active_task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 1); 344 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_active_task_info); 345 ASSERT_TRUE(std::dynamic_pointer_cast<StreamActiveTask>(task) != nullptr); 346 ASSERT_NO_THROW(task->Distribute()); 347 } 348 349 TEST_F(TestAscendGeRuntime, test_stream_active_task_create_invalid_active_stream_id_failed) { 350 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 351 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, 352 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, 353 {reinterpret_cast<rtEvent_t>(1)}); 354 std::shared_ptr<TaskInfo> stream_active_task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 2); 355 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, stream_active_task_info)); 356 } 357 358 TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_success) { 359 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 360 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, 361 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, 362 {reinterpret_cast<rtEvent_t>(1)}); 363 std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>( 364 "op_name", 0, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5); 365 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_switch_task_info); 366 ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(task) != nullptr); 367 ASSERT_NO_THROW(task->Distribute()); 368 } 369 370 TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_true_stream_id_failed) { 371 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 372 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, 373 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, 374 {reinterpret_cast<rtEvent_t>(1)}); 375 std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>( 376 "op_name", 0, 2, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5); 377 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_switch_task_info); 378 ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(task) != nullptr); 379 ASSERT_ANY_THROW(task->Distribute()); 380 } 381 382 TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_stream_id_failed) { 383 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 384 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, 385 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, 386 {reinterpret_cast<rtEvent_t>(1)}); 387 std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>( 388 "op_name", 2, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5); 389 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, stream_switch_task_info)); 390 } 391 392 TEST_F(TestAscendGeRuntime, test_tbe_task_create_success) { 393 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 394 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, 395 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, 396 {reinterpret_cast<rtEvent_t>(1)}); 397 std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>( 398 "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, 399 reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10}, 400 std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)}, 401 std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)}, 402 std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true); 403 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, tbe_task_info); 404 auto tbe_task = std::dynamic_pointer_cast<TbeTask>(task); 405 ASSERT_TRUE(tbe_task != nullptr); 406 ASSERT_NO_THROW(task->Distribute()); 407 tbe_task->args_ = new uint8_t[5]; 408 } 409 410 TEST_F(TestAscendGeRuntime, test_tbe_task_create_invalid_stream_id_failed) { 411 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 412 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, 413 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, 414 {reinterpret_cast<rtEvent_t>(1)}); 415 std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>( 416 "op_name", 3, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, 417 reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10}, 418 std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)}, 419 std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)}, 420 std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true); 421 ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, tbe_task_info)); 422 } 423 424 TEST_F(TestAscendGeRuntime, test_tbe_task_create_empty_stub_func_failed) { 425 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 426 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, 427 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, 428 {reinterpret_cast<rtEvent_t>(1)}); 429 std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>( 430 "op_name", 0, "", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, reinterpret_cast<void *>(7), 8, 431 std::vector<uint8_t>{9, 10}, std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)}, 432 std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)}, 433 std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true); 434 std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, tbe_task_info); 435 ASSERT_TRUE(std::dynamic_pointer_cast<TbeTask>(task) != nullptr); 436 ASSERT_ANY_THROW(task->Distribute()); 437 } 438 439 TEST_F(TestAscendGeRuntime, test_model_runner_success) { 440 constexpr uint32_t model_id = 0; 441 ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), 442 {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, 443 {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, 444 {reinterpret_cast<rtEvent_t>(1)}); 445 std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>( 446 "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, 447 reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10}, 448 std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)}, 449 std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)}, 450 std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true); 451 std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>( 452 "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)}, 453 std::vector<void *>{reinterpret_cast<void *>(1)}, true); 454 auto davice_model = 455 std::make_shared<DavinciModel>(std::vector<std::shared_ptr<TaskInfo>>{tbe_task_info, aicpu_task_info}, 456 std::vector<uint32_t>{}, std::vector<uint32_t>{}, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0); 457 ASSERT_NO_THROW(ModelRunner::Instance().LoadDavinciModel(0, 0, model_id, davice_model)); 458 auto iter = ModelRunner::Instance().runtime_models_.find(model_id); 459 ASSERT_TRUE(iter != ModelRunner::Instance().runtime_models_.end()); 460 auto &task_list = iter->second->task_list_; 461 task_list.clear(); 462 ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, tbe_task_info))); 463 ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, aicpu_task_info))); 464 ASSERT_NO_THROW(ModelRunner::Instance().DistributeTask(model_id)); 465 ASSERT_NO_THROW(ModelRunner::Instance().LoadModelComplete(model_id)); 466 ASSERT_NO_THROW(ModelRunner::Instance().RunModel(model_id)); 467 ASSERT_FALSE(ModelRunner::Instance().GetTaskIdList(model_id).empty()); 468 ASSERT_FALSE(ModelRunner::Instance().GetStreamIdList(model_id).empty()); 469 ASSERT_FALSE(ModelRunner::Instance().GetRuntimeInfoMap(model_id).empty()); 470 ASSERT_NO_THROW(ModelRunner::Instance().GetModelHandle(model_id)); 471 ASSERT_NO_THROW(ModelRunner::Instance().UnloadModel(model_id)); 472 } 473 } // namespace mindspore 474