1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #if (defined(__aarch64__)) || (defined(__x86_64__)) // disable test failing on FireFly/Armv7
7
8 #include "ClWorkloadFactoryHelper.hpp"
9
10 #include <test/TensorHelpers.hpp>
11
12 #include <backendsCommon/CpuTensorHandle.hpp>
13 #include <backendsCommon/WorkloadFactory.hpp>
14
15 #include <cl/ClContextControl.hpp>
16 #include <cl/ClWorkloadFactory.hpp>
17 #include <cl/OpenClTimer.hpp>
18
19 #include <backendsCommon/test/TensorCopyUtils.hpp>
20 #include <backendsCommon/test/WorkloadTestUtils.hpp>
21
22 #include <arm_compute/runtime/CL/CLScheduler.h>
23
24 #include <boost/test/unit_test.hpp>
25
26 #include <iostream>
27
28 using namespace armnn;
29
30 struct OpenClFixture
31 {
32 // Initialising ClContextControl to ensure OpenCL is loaded correctly for each test case.
33 // NOTE: Profiling needs to be enabled in ClContextControl to be able to obtain execution
34 // times from OpenClTimer.
OpenClFixtureOpenClFixture35 OpenClFixture() : m_ClContextControl(nullptr, true) {}
~OpenClFixtureOpenClFixture36 ~OpenClFixture() {}
37
38 ClContextControl m_ClContextControl;
39 };
40
41 BOOST_FIXTURE_TEST_SUITE(OpenClTimerBatchNorm, OpenClFixture)
42 using FactoryType = ClWorkloadFactory;
43
BOOST_AUTO_TEST_CASE(OpenClTimerBatchNorm)44 BOOST_AUTO_TEST_CASE(OpenClTimerBatchNorm)
45 {
46 auto memoryManager = ClWorkloadFactoryHelper::GetMemoryManager();
47 ClWorkloadFactory workloadFactory = ClWorkloadFactoryHelper::GetFactory(memoryManager);
48
49 const unsigned int width = 2;
50 const unsigned int height = 3;
51 const unsigned int channels = 2;
52 const unsigned int num = 1;
53
54 TensorInfo inputTensorInfo( {num, channels, height, width}, DataType::Float32);
55 TensorInfo outputTensorInfo({num, channels, height, width}, DataType::Float32);
56 TensorInfo tensorInfo({channels}, DataType::Float32);
57
58 auto input = MakeTensor<float, 4>(inputTensorInfo,
59 {
60 1.f, 4.f,
61 4.f, 2.f,
62 1.f, 6.f,
63
64 1.f, 1.f,
65 4.f, 1.f,
66 -2.f, 4.f
67 });
68
69 // these values are per-channel of the input
70 auto mean = MakeTensor<float, 1>(tensorInfo, { 3.f, -2.f });
71 auto variance = MakeTensor<float, 1>(tensorInfo, { 4.f, 9.f });
72 auto beta = MakeTensor<float, 1>(tensorInfo, { 3.f, 2.f });
73 auto gamma = MakeTensor<float, 1>(tensorInfo, { 2.f, 1.f });
74
75 ARMNN_NO_DEPRECATE_WARN_BEGIN
76 std::unique_ptr<ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
77 std::unique_ptr<ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
78 ARMNN_NO_DEPRECATE_WARN_END
79
80 BatchNormalizationQueueDescriptor data;
81 WorkloadInfo info;
82 ScopedCpuTensorHandle meanTensor(tensorInfo);
83 ScopedCpuTensorHandle varianceTensor(tensorInfo);
84 ScopedCpuTensorHandle betaTensor(tensorInfo);
85 ScopedCpuTensorHandle gammaTensor(tensorInfo);
86
87 AllocateAndCopyDataToITensorHandle(&meanTensor, &mean[0]);
88 AllocateAndCopyDataToITensorHandle(&varianceTensor, &variance[0]);
89 AllocateAndCopyDataToITensorHandle(&betaTensor, &beta[0]);
90 AllocateAndCopyDataToITensorHandle(&gammaTensor, &gamma[0]);
91
92 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
93 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
94 data.m_Mean = &meanTensor;
95 data.m_Variance = &varianceTensor;
96 data.m_Beta = &betaTensor;
97 data.m_Gamma = &gammaTensor;
98 data.m_Parameters.m_Eps = 0.0f;
99
100 // for each channel:
101 // substract mean, divide by standard deviation (with an epsilon to avoid div by 0)
102 // multiply by gamma and add beta
103 std::unique_ptr<IWorkload> workload = workloadFactory.CreateBatchNormalization(data, info);
104
105 inputHandle->Allocate();
106 outputHandle->Allocate();
107
108 CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
109
110 OpenClTimer openClTimer;
111
112 BOOST_CHECK_EQUAL(openClTimer.GetName(), "OpenClKernelTimer");
113
114 //Start the timer
115 openClTimer.Start();
116
117 //Execute the workload
118 workload->Execute();
119
120 //Stop the timer
121 openClTimer.Stop();
122
123 BOOST_CHECK_EQUAL(openClTimer.GetMeasurements().size(), 1);
124
125 BOOST_CHECK_EQUAL(openClTimer.GetMeasurements().front().m_Name,
126 "OpenClKernelTimer/0: batchnormalization_layer_nchw GWS[1,3,2]");
127
128 BOOST_CHECK(openClTimer.GetMeasurements().front().m_Value > 0);
129
130 }
131
132 BOOST_AUTO_TEST_SUITE_END()
133
134 #endif //aarch64 or x86_64
135