1 #include "acxxel.h"
2 #include "config.h"
3 #include "gtest/gtest.h"
4
5 namespace {
6
7 using PlatformGetter = acxxel::Expected<acxxel::Platform *> (*)();
8 class MultiDeviceTest : public ::testing::TestWithParam<PlatformGetter> {};
9
/// Copies a small host buffer to device memory and back on device 0 and on
/// device 1, each through its own stream, and checks that both round trips
/// return the original contents.
///
/// The test addresses device indices 0 and 1, so it requires the platform to
/// report at least two devices.
TEST_P(MultiDeviceTest, AsyncCopy) {
  acxxel::Platform *Platform = GetParam()().takeValue();
  int DeviceCount = Platform->getDeviceCount().getValue();
  // Fatal assertion: everything below uses device index 1, so stop here with
  // a clear message instead of requesting resources on an index the platform
  // did not report.
  ASSERT_GE(DeviceCount, 2) << "this test needs at least two devices";

  int Length = 3;
  auto A = std::unique_ptr<int[]>(new int[Length]);
  auto B0 = std::unique_ptr<int[]>(new int[Length]);
  auto B1 = std::unique_ptr<int[]>(new int[Length]);

  auto ASpan = acxxel::Span<int>(A.get(), Length);
  auto B0Span = acxxel::Span<int>(B0.get(), Length);
  auto B1Span = acxxel::Span<int>(B1.get(), Length);

  for (int I = 0; I < Length; ++I) {
    A[I] = I;
    // Give the destination buffers a value that never equals an expected
    // element, so a copy that did not happen is reported as a mismatch
    // instead of comparing uninitialized memory.
    B0[I] = -1;
    B1[I] = -1;
  }

  auto AsyncA = Platform->registerHostMem(ASpan).takeValue();
  auto AsyncB0 = Platform->registerHostMem(B0Span).takeValue();
  auto AsyncB1 = Platform->registerHostMem(B1Span).takeValue();

  acxxel::Stream Stream0 = Platform->createStream(0).takeValue();
  acxxel::Stream Stream1 = Platform->createStream(1).takeValue();
  auto Device0 = Platform->mallocD<int>(Length, 0).takeValue();
  auto Device1 = Platform->mallocD<int>(Length, 1).takeValue();

  // Host -> device 0 -> host, then wait for the stream and check its status.
  EXPECT_FALSE(Stream0.asyncCopyHToD(AsyncA, Device0, Length)
                   .asyncCopyDToH(Device0, AsyncB0, Length)
                   .sync()
                   .isError());

  // Same round trip through device 1 on its own stream.
  EXPECT_FALSE(Stream1.asyncCopyHToD(AsyncA, Device1, Length)
                   .asyncCopyDToH(Device1, AsyncB1, Length)
                   .sync()
                   .isError());

  for (int I = 0; I < Length; ++I) {
    EXPECT_EQ(B0[I], I);
    EXPECT_EQ(B1[I], I);
  }
}
51
/// Enqueues one event on a stream for device 0 and one on a stream for
/// device 1, synchronizes both streams, and checks that each event reports
/// done and that synchronizing on each event reports no error.
///
/// The test addresses device indices 0 and 1, so it requires the platform to
/// report at least two devices.
TEST_P(MultiDeviceTest, Events) {
  acxxel::Platform *Platform = GetParam()().takeValue();
  int DeviceCount = Platform->getDeviceCount().getValue();
  // Fatal assertion: streams and events are created on device index 1 below,
  // so bail out with a clear message if the platform reports fewer devices.
  ASSERT_GE(DeviceCount, 2) << "this test needs at least two devices";

  acxxel::Stream Stream0 = Platform->createStream(0).takeValue();
  acxxel::Stream Stream1 = Platform->createStream(1).takeValue();
  acxxel::Event Event0 = Platform->createEvent(0).takeValue();
  acxxel::Event Event1 = Platform->createEvent(1).takeValue();

  EXPECT_FALSE(Stream0.enqueueEvent(Event0).sync().isError());
  EXPECT_FALSE(Stream1.enqueueEvent(Event1).sync().isError());

  // Both streams were synchronized above, so the events are expected to have
  // completed by now.
  EXPECT_TRUE(Event0.isDone());
  EXPECT_TRUE(Event1.isDone());

  EXPECT_FALSE(Event0.sync().isError());
  EXPECT_FALSE(Event1.sync().isError());
}
71
// Instantiate MultiDeviceTest once per platform enabled at build time (CUDA,
// OpenCL, or both). Nothing is instantiated when neither is enabled.
//
// The conditionals are kept outside the INSTANTIATE_TEST_CASE_P invocation on
// purpose: preprocessing directives inside the argument list of a
// function-like macro have undefined behavior, so each configuration gets its
// own complete invocation.
#if defined(ACXXEL_ENABLE_CUDA) && defined(ACXXEL_ENABLE_OPENCL)
INSTANTIATE_TEST_CASE_P(BothPlatformTest, MultiDeviceTest,
                        ::testing::Values(acxxel::getCUDAPlatform,
                                          acxxel::getOpenCLPlatform));
#elif defined(ACXXEL_ENABLE_CUDA)
INSTANTIATE_TEST_CASE_P(BothPlatformTest, MultiDeviceTest,
                        ::testing::Values(acxxel::getCUDAPlatform));
#elif defined(ACXXEL_ENABLE_OPENCL)
INSTANTIATE_TEST_CASE_P(BothPlatformTest, MultiDeviceTest,
                        ::testing::Values(acxxel::getOpenCLPlatform));
#endif
86
87 } // namespace
88