1 #include <gtest/gtest.h>
2
3 #include <ATen/xpu/XPUEvent.h>
4 #include <c10/util/irange.h>
5 #include <c10/xpu/test/impl/XPUTest.h>
6
// Exercises the lazy-creation lifecycle of an XPUEvent: a fresh event is not
// yet created and reports completion, recordOnce() creates it, and other
// streams can then be made to wait on it via block().
TEST(XpuEventTest, testXPUEventBehavior) {
  // Nothing to exercise without an XPU device.
  if (!at::xpu::is_available()) {
    return;
  }

  auto record_stream = c10::xpu::getStreamFromPool();
  at::xpu::XPUEvent event;

  // Before anything is recorded the event is not created, and query() is
  // expected to report it as complete.
  EXPECT_TRUE(event.query());
  EXPECT_FALSE(event.isCreated());

  // The first recordOnce() creates the event on record_stream.
  event.recordOnce(record_stream);
  EXPECT_TRUE(event.isCreated());

  auto waiter0 = c10::xpu::getStreamFromPool();
  auto waiter1 = c10::xpu::getStreamFromPool();

  // Make both waiter streams wait on the recorded event.
  event.block(waiter0);
  event.block(waiter1);

  // Once a stream that waited on the event has drained, the event itself
  // must report completion.
  waiter0.synchronize();
  EXPECT_TRUE(event.query());
}
29
// Move-assigning an event recorded on device 1 into an event recorded on the
// current device must carry the device along with it. A stream on the current
// device must still be able to wait on the moved-in event.
TEST(XpuEventTest, testXPUEventCrossDevice) {
  // Needs at least two XPU devices.
  if (at::xpu::device_count() <= 1) {
    return;
  }

  // Pooled stream on the current device.
  const auto home_stream = at::xpu::getStreamFromPool();
  at::xpu::XPUEvent dst_event;

  // Pooled (normal-priority) stream explicitly on device 1.
  const auto peer_stream = at::xpu::getStreamFromPool(false, 1);
  at::xpu::XPUEvent src_event;

  dst_event.record(home_stream);
  src_event.record(peer_stream);

  // Replace dst_event's state with the device-1 event.
  dst_event = std::move(src_event);

  const at::Device expected_device(at::kXPU, 1);
  EXPECT_EQ(dst_event.device(), expected_device);

  // A stream on the current device waits on the device-1 event.
  dst_event.block(home_stream);

  home_stream.synchronize();
  ASSERT_TRUE(dst_event.query());
}
53
eventSync(sycl::event & event)54 void eventSync(sycl::event& event) {
55 event.wait();
56 }
57
// End-to-end check of XPUEvent against real device work:
// - host<->device copies are ordered with record()/synchronize();
// - recordOnce() does not re-record an already-created event;
// - recording on a stream from another device is rejected.
TEST(XpuEventTest, testXPUEventFunction) {
  if (!at::xpu::is_available()) {
    return;
  }

  constexpr int numel = 1024;
  int hostData[numel];
  initHostData(hostData, numel);

  auto stream = c10::xpu::getStreamFromPool();
  int* deviceData = sycl::malloc_device<int>(numel, stream);

  // H2D
  stream.queue().memcpy(deviceData, hostData, sizeof(int) * numel);
  at::xpu::XPUEvent event;
  event.record(stream);
  // To validate the implicit conversion of an XPUEvent to sycl::event.
  eventSync(event);
  EXPECT_TRUE(event.query());

  clearHostData(hostData, numel);

  // D2H
  stream.queue().memcpy(hostData, deviceData, sizeof(int) * numel);
  event.record(stream);
  event.synchronize();

  validateHostData(hostData, numel);

  clearHostData(hostData, numel);
  // D2H
  stream.queue().memcpy(hostData, deviceData, sizeof(int) * numel);
  // The event has already been created, so there will be no recording of the
  // stream via recordOnce() here.
  event.recordOnce(stream);
  EXPECT_TRUE(event.query());

  stream.synchronize();
  // The stream has drained, so the last D2H copy has landed; check it too.
  validateHostData(hostData, numel);
  sycl::free(deviceData, c10::xpu::get_device_context());

  if (at::xpu::device_count() <= 1) {
    return;
  }
  // Ask for a stream on a different device explicitly rather than calling
  // c10::xpu::set_device(). Changing the current device here would leak into
  // every test that runs after this one in the same process. Choosing the
  // index relative to `stream` guarantees the devices really differ, even if
  // the current device was not 0 when this test started.
  const int other_device = stream.device_index() == 0 ? 1 : 0;
  auto stream1 = c10::xpu::getStreamFromPool(false, other_device);
  // `event` was created on `stream`'s device, so recording it on a stream that
  // belongs to another device must throw.
  ASSERT_THROW(event.record(stream1), c10::Error);
}
105