• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <memory>
17 #include "common/common_test.h"
18 #define private public
19 #include "runtime/device/ascend/ge_runtime/model_runner.h"
20 #include "runtime/device/ascend/ge_runtime/runtime_model.h"
21 #include "runtime/device/ascend/ge_runtime/task/task_factory.h"
22 #include "runtime/device/ascend/ge_runtime/task/aicpu_task.h"
23 #include "runtime/device/ascend/ge_runtime/task/event_record_task.h"
24 #include "runtime/device/ascend/ge_runtime/task/event_wait_task.h"
25 #include "runtime/device/ascend/ge_runtime/task/hccl_task.h"
26 #include "runtime/device/ascend/ge_runtime/task/label_goto_task.h"
27 #include "runtime/device/ascend/ge_runtime/task/label_manager.h"
28 #include "runtime/device/ascend/ge_runtime/task/label_set_task.h"
29 #include "runtime/device/ascend/ge_runtime/task/label_switch_task.h"
30 #include "runtime/device/ascend/ge_runtime/task/memcpy_async_task.h"
31 #include "runtime/device/ascend/ge_runtime/task/profiler_task.h"
32 #include "runtime/device/ascend/ge_runtime/task/stream_active_task.h"
33 #include "runtime/device/ascend/ge_runtime/task/stream_switch_task.h"
34 #include "runtime/device/ascend/ge_runtime/task/tbe_task.h"
35 #undef private
36 #include "common/opskernel/ops_kernel_info_store.h"
37 
38 using namespace mindspore::ge::model_runner;
39 using namespace testing;
40 
41 class MockOpsKernelInfoStore : public ge::OpsKernelInfoStore {
42  public:
Initialize(const map<string,string> &)43   ge::Status Initialize(const map<string, string> &) override { return ge::SUCCESS; }
Finalize()44   ge::Status Finalize() override { return ge::SUCCESS; }
GetAllOpsKernelInfo(std::map<string,ge::OpInfo> & infos) const45   void GetAllOpsKernelInfo(std::map<string, ge::OpInfo> &infos) const override {}
CheckSupported(const ge::OpDescPtr & opDescPtr,std::string & un_supported_reason) const46   bool CheckSupported(const ge::OpDescPtr &opDescPtr, std::string &un_supported_reason) const override { return true; }
LoadTask(ge::GETaskInfo & task)47   ge::Status LoadTask(ge::GETaskInfo &task) override { return ge::SUCCESS; }
48 };
49 
50 namespace mindspore {
51 class TestAscendGeRuntime : public UT::Common {
52  public:
TestAscendGeRuntime()53   TestAscendGeRuntime() {}
54 
55  private:
TearDown()56   void TearDown() override {
57     {
58       std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_);
59       HcclTask::model_stream_mapping_.clear();
60     }
61   }
62 };
63 
// Asking the factory to build a task from a null TaskInfo must produce no task.
TEST_F(TestAscendGeRuntime, test_task_create_null_task_info_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  const std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(model_context, nullptr);
  ASSERT_TRUE(created == nullptr);
}
70 
// With a single-stream context, an AicpuTaskInfo on stream 0 becomes an AicpuTask whose Distribute() does not throw.
TEST_F(TestAscendGeRuntime, test_aicpu_task_create_one_stream_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<AicpuTaskInfo>(
    "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)},
    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<AicpuTask>(created), nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
82 
// Like the one-stream case, but the context has two streams and an empty string is passed where the
// one-stream case passes "ext_info".
TEST_F(TestAscendGeRuntime, test_aicpu_task_create_multi_stream_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<AicpuTaskInfo>(
    "op_name", 0, "so_name", "kernel_name", "node_def", "", std::vector<void *>{reinterpret_cast<void *>(1)},
    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<AicpuTask>(created), nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
94 
// Stream id 5 is outside the two-entry stream list, so creating the AICPU task must throw.
TEST_F(TestAscendGeRuntime, test_aicpu_task_create_invalid_stream_id_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<AicpuTaskInfo>(
    "op_name", 5, "so_name", "kernel_name", "node_def", "", std::vector<void *>{reinterpret_cast<void *>(1)},
    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  auto &factory = TaskFactory::GetInstance();
  ASSERT_ANY_THROW(factory.Create(model_context, info));
}
104 
// An EventRecordTaskInfo on stream 0 / event 0 becomes an EventRecordTask whose Distribute() does not throw.
TEST_F(TestAscendGeRuntime, test_event_record_task_create_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 0);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<EventRecordTask>(created), nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
114 
// Event id 10 is outside the single-entry event list, so creating the event-record task must throw.
TEST_F(TestAscendGeRuntime, test_event_record_task_create_invalid_event_id_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 10);
  auto &factory = TaskFactory::GetInstance();
  ASSERT_ANY_THROW(factory.Create(model_context, info));
}
122 
// An EventWaitTaskInfo on stream 0 / event 0 becomes an EventWaitTask whose Distribute() does not throw.
TEST_F(TestAscendGeRuntime, test_event_wait_task_create_success) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  // Named for what it actually holds (an event-wait task info), not the event-record test it was copied from.
  std::shared_ptr<TaskInfo> event_wait_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 0);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_wait_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<EventWaitTask>(task) != nullptr);
  ASSERT_NO_THROW(task->Distribute());
}
132 
// Event id 10 is outside the single-entry event list, so creating the event-wait task must throw.
TEST_F(TestAscendGeRuntime, test_event_wait_task_create_invalid_event_id_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  // Named for what it actually holds (an event-wait task info), not the event-record test it was copied from.
  std::shared_ptr<TaskInfo> event_wait_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 10);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_wait_task_info));
}
140 
// An HcclTaskInfo backed by the do-nothing kernel store becomes an HcclTask whose Distribute() does not throw.
TEST_F(TestAscendGeRuntime, test_hccl_task_create_success) {
  MockOpsKernelInfoStore kernel_store;
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<HcclTaskInfo>(
    "op_name", 0, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4,
    5, std::vector<uint8_t>(6, 7), reinterpret_cast<void *>(&kernel_store), 9, 10, 11, 12, "group", true);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<HcclTask>(created), nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
153 
// Three HCCL tasks on the same model and stream id request 3, 5 and 4 streams. After distribution, the
// static HcclTask mapping for (model, stream_id) must hold exactly max(3, 5, 4) streams. Every recorded
// weak reference must still be lockable while the three tasks are in scope.
TEST_F(TestAscendGeRuntime, test_hccl_task_create_stream_reuse_success) {
  const rtModel_t model = reinterpret_cast<rtModel_t>(0x12345678);
  const rtStream_t stream = reinterpret_cast<rtStream_t>(0x87654321);
  constexpr uint32_t stream_id = 0;
  constexpr int64_t task1_stream_num = 3;
  constexpr int64_t task2_stream_num = 5;
  constexpr int64_t task3_stream_num = 4;
  MockOpsKernelInfoStore ops_kernel_info_store;
  ModelContext model_context(0, 0, 0, model, reinterpret_cast<rtStream_t>(2), {stream},
                             {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> hccl_task_info_1 = std::make_shared<HcclTaskInfo>(
    "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
    reinterpret_cast<void *>(3), 4, task1_stream_num, std::vector<uint8_t>(6, 7),
    reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
  std::shared_ptr<TaskInfo> hccl_task_info_2 = std::make_shared<HcclTaskInfo>(
    "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
    reinterpret_cast<void *>(3), 4, task2_stream_num, std::vector<uint8_t>(6, 7),
    reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
  std::shared_ptr<TaskInfo> hccl_task_info_3 = std::make_shared<HcclTaskInfo>(
    "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
    reinterpret_cast<void *>(3), 4, task3_stream_num, std::vector<uint8_t>(6, 7),
    reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
  std::shared_ptr<Task> task_1 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_1);
  std::shared_ptr<Task> task_2 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_2);
  std::shared_ptr<Task> task_3 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_3);
  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_1) != nullptr);
  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_2) != nullptr);
  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_3) != nullptr);
  ASSERT_NO_THROW(task_1->Distribute());
  ASSERT_NO_THROW(task_2->Distribute());
  ASSERT_NO_THROW(task_3->Distribute());
  {
    // The mapping is a static shared with HcclTask, so inspect it under its mutex.
    std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_);
    auto model_iter = HcclTask::model_stream_mapping_.find(model);
    ASSERT_NE(model_iter, HcclTask::model_stream_mapping_.end());
    auto stream_iter = model_iter->second.find(stream_id);
    ASSERT_NE(stream_iter, model_iter->second.end());
    const auto &stream_vec = stream_iter->second;
    // Convert the signed expectation explicitly rather than having ASSERT_EQ compare size_t with int64_t.
    const auto expected_stream_num =
      static_cast<size_t>(std::max(task1_stream_num, std::max(task2_stream_num, task3_stream_num)));
    ASSERT_EQ(stream_vec.size(), expected_stream_num);
    for (const auto &s : stream_vec) {
      auto shared = s.lock();
      ASSERT_TRUE(shared != nullptr);
    }
  }
}
199 
// A LabelGotoTaskInfo on stream 0 / label 0 becomes a LabelGotoTask whose Distribute() does not throw.
TEST_F(TestAscendGeRuntime, test_label_goto_task_create_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  auto goto_task = std::dynamic_pointer_cast<LabelGotoTask>(created);
  ASSERT_NE(goto_task, nullptr);
  ASSERT_NO_THROW(created->Distribute());
  // NOTE(review): hands the task a new[] buffer after Distribute(), presumably so that the destructor's
  // index_value_ release path is exercised. Confirm that the stubbed runtime frees it with delete[];
  // otherwise this leaks, and any value Distribute() stored in index_value_ is silently overwritten.
  goto_task->index_value_ = new uint8_t[5];
}
211 
// Label id 1 is outside the single-entry label list, so creating the label-goto task must throw.
TEST_F(TestAscendGeRuntime, test_label_goto_task_create_invalid_label_id_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 1);
  auto &factory = TaskFactory::GetInstance();
  ASSERT_ANY_THROW(factory.Create(model_context, info));
}
219 
// Two label-goto tasks built from the same TaskInfo must end up sharing the same label_info_ after distribution.
TEST_F(TestAscendGeRuntime, test_label_goto_task_reuse_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> first = factory.Create(model_context, info);
  std::shared_ptr<Task> second = factory.Create(model_context, info);
  auto first_goto = std::dynamic_pointer_cast<LabelGotoTask>(first);
  auto second_goto = std::dynamic_pointer_cast<LabelGotoTask>(second);
  ASSERT_NE(first_goto, nullptr);
  ASSERT_NO_THROW(first->Distribute());
  ASSERT_NE(second_goto, nullptr);
  ASSERT_NO_THROW(second->Distribute());
  ASSERT_EQ(first_goto->label_info_, second_goto->label_info_);
}
235 
// A LabelSetTaskInfo on stream 0 / label 0 becomes a LabelSetTask whose Distribute() does not throw.
TEST_F(TestAscendGeRuntime, test_label_set_task_create_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 0);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<LabelSetTask>(created), nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
245 
// Label id 1 is outside the single-entry label list, so creating the label-set task must throw.
// A LabelSetTaskInfo is built here so that this test covers the LabelSet task path, as its name says.
TEST_F(TestAscendGeRuntime, test_label_set_task_create_invalid_label_id_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 1);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_set_task_info));
}
253 
// A LabelSwitchTaskInfo that selects among labels {0, 1} becomes a LabelSwitchTask whose Distribute() does not throw.
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info =
    std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<LabelSwitchTask>(created), nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
264 
// Stream id 1 is outside the single-entry stream list, so creating the label-switch task must throw.
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_stream_id_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info =
    std::make_shared<LabelSwitchTaskInfo>("op_name", 1, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
  auto &factory = TaskFactory::GetInstance();
  ASSERT_ANY_THROW(factory.Create(model_context, info));
}
273 
// The label list {0, 1, 2} references id 2, which is beyond the two-entry label list, so creation must throw.
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_label_id_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<LabelSwitchTaskInfo>(
    "op_name", 0, 3, std::vector<uint32_t>{0, 1, 2}, reinterpret_cast<void *>(1));
  auto &factory = TaskFactory::GetInstance();
  ASSERT_ANY_THROW(factory.Create(model_context, info));
}
282 
// Two label-switch tasks built from the same TaskInfo must end up sharing the same label_info_ after distribution.
TEST_F(TestAscendGeRuntime, test_label_switch_task_reuse_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info =
    std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> first = factory.Create(model_context, info);
  std::shared_ptr<Task> second = factory.Create(model_context, info);
  auto first_switch = std::dynamic_pointer_cast<LabelSwitchTask>(first);
  auto second_switch = std::dynamic_pointer_cast<LabelSwitchTask>(second);
  ASSERT_NE(first_switch, nullptr);
  ASSERT_NE(second_switch, nullptr);
  ASSERT_NO_THROW(first->Distribute());
  ASSERT_NO_THROW(second->Distribute());
  ASSERT_EQ(first_switch->label_info_, second_switch->label_info_);
}
299 
// A MemcpyAsyncTaskInfo on stream 0 becomes a MemcpyAsyncTask whose Distribute() does not throw.
TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<MemcpyAsyncTaskInfo>(
    "op_name", 0, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<MemcpyAsyncTask>(created), nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
310 
// Stream id 1 is outside the single-entry stream list, so creating the memcpy task must throw.
TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_invalid_stream_id_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<MemcpyAsyncTaskInfo>(
    "op_name", 1, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true);
  auto &factory = TaskFactory::GetInstance();
  ASSERT_ANY_THROW(factory.Create(model_context, info));
}
319 
// A ProfilerTraceTaskInfo on stream 0 becomes a ProfilerTask whose Distribute() does not throw.
TEST_F(TestAscendGeRuntime, test_profiler_task_create_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 0, 1, true, 2);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<ProfilerTask>(created), nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
329 
// Stream id 1 is outside the single-entry stream list, so creating the profiler task must throw.
TEST_F(TestAscendGeRuntime, test_profiler_task_create_invalid_stream_id_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 1, 1, true, 2);
  auto &factory = TaskFactory::GetInstance();
  ASSERT_ANY_THROW(factory.Create(model_context, info));
}
337 
// A StreamActiveTaskInfo on stream 0 that activates stream 1 becomes a StreamActiveTask whose Distribute()
// does not throw.
TEST_F(TestAscendGeRuntime, test_stream_active_task_create_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 1);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<StreamActiveTask>(created), nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
348 
// Active stream id 2 is outside the two-entry stream list, so creating the stream-active task must throw.
TEST_F(TestAscendGeRuntime, test_stream_active_task_create_invalid_active_stream_id_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 2);
  auto &factory = TaskFactory::GetInstance();
  ASSERT_ANY_THROW(factory.Create(model_context, info));
}
357 
// A StreamSwitchTaskInfo on stream 0 with true-branch stream 1 becomes a StreamSwitchTask whose Distribute()
// does not throw.
TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<StreamSwitchTaskInfo>(
    "op_name", 0, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<StreamSwitchTask>(created), nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
369 
// True-branch stream id 2 is outside the two-entry stream list. Construction still succeeds, but
// Distribute() is expected to throw.
TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_true_stream_id_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<StreamSwitchTaskInfo>(
    "op_name", 0, 2, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<StreamSwitchTask>(created), nullptr);
  ASSERT_ANY_THROW(created->Distribute());
}
381 
// Stream id 2 is outside the two-entry stream list, so creating the stream-switch task must throw.
TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_stream_id_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<StreamSwitchTaskInfo>(
    "op_name", 2, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
  auto &factory = TaskFactory::GetInstance();
  ASSERT_ANY_THROW(factory.Create(model_context, info));
}
391 
// A TbeTaskInfo with a non-empty stub function on stream 0 becomes a TbeTask whose Distribute() does not throw.
TEST_F(TestAscendGeRuntime, test_tbe_task_create_success) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<TbeTaskInfo>(
    "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  auto tbe_task = std::dynamic_pointer_cast<TbeTask>(created);
  ASSERT_NE(tbe_task, nullptr);
  ASSERT_NO_THROW(created->Distribute());
  // NOTE(review): hands the task a new[] buffer after Distribute(), presumably so that the destructor's
  // args_ release path is exercised. Confirm that the stubbed runtime frees it with delete[];
  // otherwise this leaks, and any value Distribute() stored in args_ is silently overwritten.
  tbe_task->args_ = new uint8_t[5];
}
409 
// Stream id 3 is outside the two-entry stream list, so creating the TBE task must throw.
TEST_F(TestAscendGeRuntime, test_tbe_task_create_invalid_stream_id_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<TbeTaskInfo>(
    "op_name", 3, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  auto &factory = TaskFactory::GetInstance();
  ASSERT_ANY_THROW(factory.Create(model_context, info));
}
423 
// With an empty stub function name, the TBE task is still created, but Distribute() is expected to throw.
TEST_F(TestAscendGeRuntime, test_tbe_task_create_empty_stub_func_failed) {
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> info = std::make_shared<TbeTaskInfo>(
    "op_name", 0, "", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, reinterpret_cast<void *>(7), 8,
    std::vector<uint8_t>{9, 10}, std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  auto &factory = TaskFactory::GetInstance();
  std::shared_ptr<Task> created = factory.Create(model_context, info);
  ASSERT_NE(std::dynamic_pointer_cast<TbeTask>(created), nullptr);
  ASSERT_ANY_THROW(created->Distribute());
}
438 
// Exercises ModelRunner end to end. The test loads a DavinciModel that carries one TBE and one AICPU task
// info, replaces the loaded model's task list with tasks created against a local context, and then
// distributes, completes, runs, queries and unloads the model.
TEST_F(TestAscendGeRuntime, test_model_runner_success) {
  constexpr uint32_t model_id = 0;
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                             label_list, event_list);
  std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
    "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
    "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)},
    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  auto davinci_model =
    std::make_shared<DavinciModel>(std::vector<std::shared_ptr<TaskInfo>>{tbe_task_info, aicpu_task_info},
                                   std::vector<uint32_t>{}, std::vector<uint32_t>{}, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0);

  auto &runner = ModelRunner::Instance();
  auto &factory = TaskFactory::GetInstance();
  ASSERT_NO_THROW(runner.LoadDavinciModel(0, 0, model_id, davinci_model));

  const auto model_iter = runner.runtime_models_.find(model_id);
  ASSERT_TRUE(model_iter != runner.runtime_models_.end());
  // Discard whatever task list loading produced and substitute tasks built from the local model_context.
  auto &task_list = model_iter->second->task_list_;
  task_list.clear();
  ASSERT_NO_THROW(task_list.emplace_back(factory.Create(model_context, tbe_task_info)));
  ASSERT_NO_THROW(task_list.emplace_back(factory.Create(model_context, aicpu_task_info)));

  ASSERT_NO_THROW(runner.DistributeTask(model_id));
  ASSERT_NO_THROW(runner.LoadModelComplete(model_id));
  ASSERT_NO_THROW(runner.RunModel(model_id));
  ASSERT_FALSE(runner.GetTaskIdList(model_id).empty());
  ASSERT_FALSE(runner.GetStreamIdList(model_id).empty());
  ASSERT_FALSE(runner.GetRuntimeInfoMap(model_id).empty());
  ASSERT_NO_THROW(runner.GetModelHandle(model_id));
  ASSERT_NO_THROW(runner.UnloadModel(model_id));
}
473 }  // namespace mindspore
474