1 //===--- opencl_test.cpp - Tests for OpenCL and the Acxxel API ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "acxxel.h"
10 #include "gtest/gtest.h"
11
12 #include <array>
13 #include <cstring>
14
15 namespace {
16
17 static const char *SaxpyKernelSource = R"(
18 __kernel void saxpyKernel(float A, __global float *X, __global float *Y, int N) {
19 int I = get_global_id(0);
20 if (I < N)
21 X[I] = A * X[I] + Y[I];
22 }
23 )";
24
TEST(OpenCL,Saxpy)25 TEST(OpenCL, Saxpy) {
26 constexpr size_t Length = 3;
27
28 float A = 2.f;
29 std::array<float, Length> X = {{0.f, 1.f, 2.f}};
30 std::array<float, Length> Y = {{3.f, 4.f, 5.f}};
31 std::array<float, Length> Expected = {{3.f, 6.f, 9.f}};
32
33 acxxel::Platform *OpenCL = acxxel::getOpenCLPlatform().getValue();
34 acxxel::Stream Stream = OpenCL->createStream().takeValue();
35 auto DeviceX = OpenCL->mallocD<float>(Length).takeValue();
36 auto DeviceY = OpenCL->mallocD<float>(Length).takeValue();
37 Stream.syncCopyHToD(X, DeviceX);
38 Stream.syncCopyHToD(Y, DeviceY);
39 acxxel::Program Program =
40 OpenCL
41 ->createProgramFromSource(acxxel::Span<const char>(
42 SaxpyKernelSource, std::strlen(SaxpyKernelSource)))
43 .takeValue();
44 acxxel::Kernel Kernel = Program.createKernel("saxpyKernel").takeValue();
45 float *RawX = static_cast<float *>(DeviceX);
46 float *RawY = static_cast<float *>(DeviceY);
47 int IntLength = Length;
48 void *Arguments[] = {&A, &RawX, &RawY, &IntLength};
49 size_t ArgumentSizes[] = {sizeof(float), sizeof(float *), sizeof(float *),
50 sizeof(int)};
51 EXPECT_FALSE(
52 Stream.asyncKernelLaunch(Kernel, Length, Arguments, ArgumentSizes)
53 .takeStatus()
54 .isError());
55 Stream.syncCopyDToH(DeviceX, X);
56 EXPECT_FALSE(Stream.sync().isError());
57
58 EXPECT_EQ(X, Expected);
59 }
60
61 } // namespace
62