1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <algorithm>
#include <memory>
17 #include "common/common_test.h"
18 #define private public
19 #include "runtime/device/ascend/ge_runtime/model_runner.h"
20 #include "runtime/device/ascend/ge_runtime/runtime_model.h"
21 #include "runtime/device/ascend/ge_runtime/task/task_factory.h"
22 #include "runtime/device/ascend/ge_runtime/task/aicpu_task.h"
23 #include "runtime/device/ascend/ge_runtime/task/event_record_task.h"
24 #include "runtime/device/ascend/ge_runtime/task/event_wait_task.h"
25 #include "runtime/device/ascend/ge_runtime/task/hccl_task.h"
26 #include "runtime/device/ascend/ge_runtime/task/label_goto_task.h"
27 #include "runtime/device/ascend/ge_runtime/task/label_manager.h"
28 #include "runtime/device/ascend/ge_runtime/task/label_set_task.h"
29 #include "runtime/device/ascend/ge_runtime/task/label_switch_task.h"
30 #include "runtime/device/ascend/ge_runtime/task/memcpy_async_task.h"
31 #include "runtime/device/ascend/ge_runtime/task/profiler_task.h"
32 #include "runtime/device/ascend/ge_runtime/task/stream_active_task.h"
33 #include "runtime/device/ascend/ge_runtime/task/stream_switch_task.h"
34 #include "runtime/device/ascend/ge_runtime/task/tbe_task.h"
35 #undef private
36 #include "common/opskernel/ops_kernel_info_store.h"
37
38 using namespace mindspore::ge::model_runner;
39 using namespace testing;
40
41 class MockOpsKernelInfoStore : public ge::OpsKernelInfoStore {
42 public:
Initialize(const map<string,string> &)43 ge::Status Initialize(const map<string, string> &) override { return ge::SUCCESS; }
Finalize()44 ge::Status Finalize() override { return ge::SUCCESS; }
GetAllOpsKernelInfo(std::map<string,ge::OpInfo> & infos) const45 void GetAllOpsKernelInfo(std::map<string, ge::OpInfo> &infos) const override {}
CheckSupported(const ge::OpDescPtr & opDescPtr,std::string & un_supported_reason) const46 bool CheckSupported(const ge::OpDescPtr &opDescPtr, std::string &un_supported_reason) const override { return true; }
LoadTask(ge::GETaskInfo & task)47 ge::Status LoadTask(ge::GETaskInfo &task) override { return ge::SUCCESS; }
48 };
49
50 namespace mindspore {
51 class TestAscendGeRuntime : public UT::Common {
52 public:
TestAscendGeRuntime()53 TestAscendGeRuntime() {}
54
55 private:
TearDown()56 void TearDown() override {
57 {
58 std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_);
59 HcclTask::model_stream_mapping_.clear();
60 }
61 }
62 };
63
TEST_F(TestAscendGeRuntime, test_task_create_null_task_info_failed) {
  // Create() is expected to hand back nullptr when it is given no task info at all.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, nullptr);
  ASSERT_TRUE(created == nullptr);
}
70
TEST_F(TestAscendGeRuntime, test_aicpu_task_create_one_stream_success) {
  // Single-stream context; an AICPU task on stream 0 is expected to be created and distributed without throwing.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info =
    std::make_shared<AicpuTaskInfo>("op_name", 0, "so_name", "kernel_name", "node_def", "ext_info",
                                    std::vector<void *>{reinterpret_cast<void *>(1)},
                                    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(created) != nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
82
TEST_F(TestAscendGeRuntime, test_aicpu_task_create_multi_stream_success) {
  // Two-stream context; an AICPU task on stream 0 (with an empty sixth string argument) is expected to succeed.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info =
    std::make_shared<AicpuTaskInfo>("op_name", 0, "so_name", "kernel_name", "node_def", "",
                                    std::vector<void *>{reinterpret_cast<void *>(1)},
                                    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(created) != nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
94
TEST_F(TestAscendGeRuntime, test_aicpu_task_create_invalid_stream_id_failed) {
  // The context holds only streams 0 and 1, so an AICPU task on stream 5 is expected to make Create() throw.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info =
    std::make_shared<AicpuTaskInfo>("op_name", 5, "so_name", "kernel_name", "node_def", "",
                                    std::vector<void *>{reinterpret_cast<void *>(1)},
                                    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(context, task_info));
}
104
TEST_F(TestAscendGeRuntime, test_event_record_task_create_success) {
  // Event 0 on stream 0 both exist in the context, so the record task is expected to build and distribute.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 0);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<EventRecordTask>(created) != nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
114
TEST_F(TestAscendGeRuntime, test_event_record_task_create_invalid_event_id_failed) {
  // Only one event is registered, so referring to event 10 is expected to make Create() throw.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 10);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(context, task_info));
}
122
TEST_F(TestAscendGeRuntime, test_event_wait_task_create_success) {
  // Event 0 on stream 0 both exist in the context, so the wait task is expected to build and distribute.
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  // Named after what it holds: an EventWaitTaskInfo (previously misnamed event_record_task_info).
  std::shared_ptr<TaskInfo> event_wait_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 0);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_wait_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<EventWaitTask>(task) != nullptr);
  ASSERT_NO_THROW(task->Distribute());
}
132
TEST_F(TestAscendGeRuntime, test_event_wait_task_create_invalid_event_id_failed) {
  // Only one event is registered, so waiting on event 10 is expected to make Create() throw.
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  // Named after what it holds: an EventWaitTaskInfo (previously misnamed event_record_task_info).
  std::shared_ptr<TaskInfo> event_wait_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 10);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_wait_task_info));
}
140
TEST_F(TestAscendGeRuntime, test_hccl_task_create_success) {
  // An HCCL task on stream 0 that points at a no-op kernel info store is expected to build and distribute.
  MockOpsKernelInfoStore kernel_store;
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<HcclTaskInfo>(
    "op_name", 0, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), reinterpret_cast<void *>(3),
    4, 5, std::vector<uint8_t>(6, 7), reinterpret_cast<void *>(&kernel_store), 9, 10, 11, 12, "group", true);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(created) != nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
153
TEST_F(TestAscendGeRuntime, test_hccl_task_create_stream_reuse_success) {
  // Three HCCL tasks on the same model and stream request different numbers of streams. The vector recorded in
  // HcclTask's static model -> stream-id mapping is expected to be sized by the largest request (streams are reused
  // across tasks rather than accumulated), and every recorded stream must still be alive.
  const rtModel_t model = reinterpret_cast<rtModel_t>(0x12345678);
  const rtStream_t stream = reinterpret_cast<rtStream_t>(0x87654321);
  constexpr uint32_t stream_id = 0;
  constexpr int64_t task1_stream_num = 3;
  constexpr int64_t task2_stream_num = 5;
  constexpr int64_t task3_stream_num = 4;
  MockOpsKernelInfoStore ops_kernel_info_store;
  ModelContext model_context(0, 0, 0, model, reinterpret_cast<rtStream_t>(2), {stream},
                             {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
  // The three task infos differ only in the requested stream count.
  auto make_hccl_task_info = [&](int64_t stream_num) -> std::shared_ptr<TaskInfo> {
    return std::make_shared<HcclTaskInfo>("op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1),
                                          reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, stream_num,
                                          std::vector<uint8_t>(6, 7), reinterpret_cast<void *>(&ops_kernel_info_store),
                                          9, 10, 11, 12, "group", true);
  };
  std::shared_ptr<TaskInfo> hccl_task_info_1 = make_hccl_task_info(task1_stream_num);
  std::shared_ptr<TaskInfo> hccl_task_info_2 = make_hccl_task_info(task2_stream_num);
  std::shared_ptr<TaskInfo> hccl_task_info_3 = make_hccl_task_info(task3_stream_num);
  std::shared_ptr<Task> task_1 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_1);
  std::shared_ptr<Task> task_2 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_2);
  std::shared_ptr<Task> task_3 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_3);
  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_1) != nullptr);
  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_2) != nullptr);
  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_3) != nullptr);
  ASSERT_NO_THROW(task_1->Distribute());
  ASSERT_NO_THROW(task_2->Distribute());
  ASSERT_NO_THROW(task_3->Distribute());
  {
    std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_);
    auto model_iter = HcclTask::model_stream_mapping_.find(model);
    ASSERT_NE(model_iter, HcclTask::model_stream_mapping_.end());
    auto stream_iter = model_iter->second.find(stream_id);
    ASSERT_NE(stream_iter, model_iter->second.end());
    const auto &stream_vec = stream_iter->second;
    // size() is unsigned while the requested counts are int64_t; convert the (positive) maximum explicitly
    // instead of relying on a signed/unsigned comparison inside ASSERT_EQ.
    const auto expected_stream_count =
      static_cast<size_t>(std::max({task1_stream_num, task2_stream_num, task3_stream_num}));
    ASSERT_EQ(stream_vec.size(), expected_stream_count);
    for (const auto &s : stream_vec) {
      ASSERT_TRUE(s.lock() != nullptr);
    }
  }
}
199
TEST_F(TestAscendGeRuntime, test_label_goto_task_create_success) {
  // A goto to label 0 on stream 0 is expected to build and distribute.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  auto as_goto = std::dynamic_pointer_cast<LabelGotoTask>(created);
  ASSERT_TRUE(as_goto != nullptr);
  ASSERT_NO_THROW(created->Distribute());
  // NOTE(review): index_value_ is overwritten with a heap buffer after Distribute(), presumably so the task's
  // destructor exercises its release path. Confirm the stubbed runtime free uses delete[] and that Distribute()
  // left no prior allocation here; otherwise this leaks.
  as_goto->index_value_ = new uint8_t[5];
}
211
TEST_F(TestAscendGeRuntime, test_label_goto_task_create_invalid_label_id_failed) {
  // Only label 0 exists, so a goto to label 1 is expected to make Create() throw.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 1);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(context, task_info));
}
219
TEST_F(TestAscendGeRuntime, test_label_goto_task_reuse_success) {
  // Two goto tasks built from the same task info are expected to end up with equal label_info_ members.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0);
  std::shared_ptr<Task> first = TaskFactory::GetInstance().Create(context, task_info);
  std::shared_ptr<Task> second = TaskFactory::GetInstance().Create(context, task_info);
  auto first_goto = std::dynamic_pointer_cast<LabelGotoTask>(first);
  auto second_goto = std::dynamic_pointer_cast<LabelGotoTask>(second);
  ASSERT_TRUE(first_goto != nullptr);
  ASSERT_NO_THROW(first->Distribute());
  ASSERT_TRUE(second_goto != nullptr);
  ASSERT_NO_THROW(second->Distribute());
  ASSERT_EQ(first_goto->label_info_, second_goto->label_info_);
}
235
TEST_F(TestAscendGeRuntime, test_label_set_task_create_success) {
  // Setting label 0 on stream 0 is expected to build and distribute.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 0);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<LabelSetTask>(created) != nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
245
TEST_F(TestAscendGeRuntime, test_label_set_task_create_invalid_label_id_failed) {
  // Only label 0 exists in the context, so a LabelSet task referring to label 1 is expected to make Create() throw.
  // The info must be a LabelSetTaskInfo; a LabelGotoTaskInfo would be dispatched to LabelGotoTask instead and
  // leave the LabelSet validation path untested.
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 1);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_set_task_info));
}
253
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_success) {
  // A switch over labels {0, 1} with two labels registered is expected to build and distribute.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<LabelSwitchTaskInfo>(
    "op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<LabelSwitchTask>(created) != nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
264
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_stream_id_failed) {
  // Only stream 0 exists, so a switch placed on stream 1 is expected to make Create() throw.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<LabelSwitchTaskInfo>(
    "op_name", 1, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(context, task_info));
}
273
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_label_id_failed) {
  // The switch references labels {0, 1, 2} but only two labels are registered, so Create() is expected to throw.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<LabelSwitchTaskInfo>(
    "op_name", 0, 3, std::vector<uint32_t>{0, 1, 2}, reinterpret_cast<void *>(1));
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(context, task_info));
}
282
TEST_F(TestAscendGeRuntime, test_label_switch_task_reuse_success) {
  // Two switch tasks built from the same task info are expected to end up with equal label_info_ members.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<LabelSwitchTaskInfo>(
    "op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
  std::shared_ptr<Task> first = TaskFactory::GetInstance().Create(context, task_info);
  std::shared_ptr<Task> second = TaskFactory::GetInstance().Create(context, task_info);
  auto first_switch = std::dynamic_pointer_cast<LabelSwitchTask>(first);
  auto second_switch = std::dynamic_pointer_cast<LabelSwitchTask>(second);
  ASSERT_TRUE(first_switch != nullptr);
  ASSERT_TRUE(second_switch != nullptr);
  ASSERT_NO_THROW(first->Distribute());
  ASSERT_NO_THROW(second->Distribute());
  ASSERT_EQ(first_switch->label_info_, second_switch->label_info_);
}
299
TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_success) {
  // An async memcpy task on stream 0 is expected to build and distribute.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<MemcpyAsyncTaskInfo>(
    "op_name", 0, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<MemcpyAsyncTask>(created) != nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
310
TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_invalid_stream_id_failed) {
  // Only stream 0 exists, so a memcpy task on stream 1 is expected to make Create() throw.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<MemcpyAsyncTaskInfo>(
    "op_name", 1, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(context, task_info));
}
319
TEST_F(TestAscendGeRuntime, test_profiler_task_create_success) {
  // A profiler trace task on stream 0 is expected to build and distribute.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 0, 1, true, 2);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<ProfilerTask>(created) != nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
329
TEST_F(TestAscendGeRuntime, test_profiler_task_create_invalid_stream_id_failed) {
  // Only stream 0 exists, so a profiler task on stream 1 is expected to make Create() throw.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 1, 1, true, 2);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(context, task_info));
}
337
TEST_F(TestAscendGeRuntime, test_stream_active_task_create_success) {
  // Stream 0 activating stream 1 (both registered) is expected to build and distribute.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 1);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<StreamActiveTask>(created) != nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
348
TEST_F(TestAscendGeRuntime, test_stream_active_task_create_invalid_active_stream_id_failed) {
  // Only streams 0 and 1 exist, so activating stream 2 is expected to make Create() throw.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 2);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(context, task_info));
}
357
TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_success) {
  // A switch on stream 0 whose target is stream 1 (both registered) is expected to build and distribute.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<StreamSwitchTaskInfo>(
    "op_name", 0, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(created) != nullptr);
  ASSERT_NO_THROW(created->Distribute());
}
369
TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_true_stream_id_failed) {
  // The switch itself sits on a valid stream, but its true-branch stream id (2) is past the two-entry stream list.
  // Create() is expected to succeed; the failure is expected to surface from Distribute().
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<StreamSwitchTaskInfo>(
    "op_name", 0, 2, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(created) != nullptr);
  ASSERT_ANY_THROW(created->Distribute());
}
381
TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_stream_id_failed) {
  // Placing the switch on stream 2, past the two-entry stream list, is expected to make Create() throw.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<StreamSwitchTaskInfo>(
    "op_name", 2, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(context, task_info));
}
391
TEST_F(TestAscendGeRuntime, test_tbe_task_create_success) {
  // A fully populated TBE task on stream 0 is expected to build and distribute.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<TbeTaskInfo>(
    "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  auto as_tbe = std::dynamic_pointer_cast<TbeTask>(created);
  ASSERT_TRUE(as_tbe != nullptr);
  ASSERT_NO_THROW(created->Distribute());
  // NOTE(review): args_ is overwritten with a heap buffer after Distribute(), presumably so the task's destructor
  // exercises its release path. Confirm the stubbed runtime free uses delete[] and that Distribute() left no prior
  // allocation here; otherwise this leaks.
  as_tbe->args_ = new uint8_t[5];
}
409
TEST_F(TestAscendGeRuntime, test_tbe_task_create_invalid_stream_id_failed) {
  // Only streams 0 and 1 exist, so a TBE task on stream 3 is expected to make Create() throw.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<TbeTaskInfo>(
    "op_name", 3, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(context, task_info));
}
423
TEST_F(TestAscendGeRuntime, test_tbe_task_create_empty_stub_func_failed) {
  // An empty stub function name is accepted by Create(); Distribute() is expected to throw.
  const std::vector<rtStream_t> stream_list{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)};
  const std::vector<rtLabel_t> label_list{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)};
  const std::vector<rtEvent_t> event_list{reinterpret_cast<rtEvent_t>(1)};
  ModelContext context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), stream_list,
                       label_list, event_list);
  std::shared_ptr<TaskInfo> task_info = std::make_shared<TbeTaskInfo>(
    "op_name", 0, "", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, reinterpret_cast<void *>(7), 8,
    std::vector<uint8_t>{9, 10}, std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  std::shared_ptr<Task> created = TaskFactory::GetInstance().Create(context, task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<TbeTask>(created) != nullptr);
  ASSERT_ANY_THROW(created->Distribute());
}
438
TEST_F(TestAscendGeRuntime, test_model_runner_success) {
  // End-to-end pass through ModelRunner. Load a DavinciModel holding one TBE and one AICPU task, then replace its
  // task list with tasks created against a local context. Finally distribute, complete, run, query and unload it.
  constexpr uint32_t model_id = 0;
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
    "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
    "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)},
    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  // Spelled davinci_model (was misspelled davice_model) to match the DavinciModel type it holds.
  auto davinci_model =
    std::make_shared<DavinciModel>(std::vector<std::shared_ptr<TaskInfo>>{tbe_task_info, aicpu_task_info},
                                   std::vector<uint32_t>{}, std::vector<uint32_t>{}, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0);
  ASSERT_NO_THROW(ModelRunner::Instance().LoadDavinciModel(0, 0, model_id, davinci_model));
  auto iter = ModelRunner::Instance().runtime_models_.find(model_id);
  ASSERT_TRUE(iter != ModelRunner::Instance().runtime_models_.end());
  // Swap out whatever LoadDavinciModel() built for tasks created directly from the local context.
  auto &task_list = iter->second->task_list_;
  task_list.clear();
  ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, tbe_task_info)));
  ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, aicpu_task_info)));
  ASSERT_NO_THROW(ModelRunner::Instance().DistributeTask(model_id));
  ASSERT_NO_THROW(ModelRunner::Instance().LoadModelComplete(model_id));
  ASSERT_NO_THROW(ModelRunner::Instance().RunModel(model_id));
  ASSERT_FALSE(ModelRunner::Instance().GetTaskIdList(model_id).empty());
  ASSERT_FALSE(ModelRunner::Instance().GetStreamIdList(model_id).empty());
  ASSERT_FALSE(ModelRunner::Instance().GetRuntimeInfoMap(model_id).empty());
  ASSERT_NO_THROW(ModelRunner::Instance().GetModelHandle(model_id));
  ASSERT_NO_THROW(ModelRunner::Instance().UnloadModel(model_id));
}
473 } // namespace mindspore
474