• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2017-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_SCALE_FIXTURE
25 #define ARM_COMPUTE_TEST_SCALE_FIXTURE
26 
27 #include "arm_compute/core/TensorShape.h"
28 #include "arm_compute/core/Types.h"
29 #include "tests/Globals.h"
30 #include "tests/Utils.h"
31 #include "tests/framework/Fixture.h"
32 
33 namespace arm_compute
34 {
35 namespace test
36 {
37 namespace benchmark
38 {
39 template <typename TensorType, typename Function, typename Accessor>
40 class ScaleFixture : public framework::Fixture
41 {
42 public:
43     template <typename...>
setup(TensorShape shape,DataType data_type,DataLayout data_layout,InterpolationPolicy policy,BorderMode border_mode,SamplingPolicy sampling_policy)44     void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy)
45     {
46         constexpr float max_width  = 8192.0f;
47         constexpr float max_height = 6384.0f;
48 
49         // Change shape in case of NHWC.
50         if(data_layout == DataLayout::NHWC)
51         {
52             permute(shape, PermutationVector(2U, 0U, 1U));
53         }
54 
55         std::mt19937                          generator(library->seed());
56         std::uniform_real_distribution<float> distribution_float(0.25f, 3.0f);
57         float                                 scale_x = distribution_float(generator);
58         float                                 scale_y = distribution_float(generator);
59 
60         scale_x = ((shape.x() * scale_x) > max_width) ? (max_width / shape.x()) : scale_x;
61         scale_y = ((shape.y() * scale_y) > max_height) ? (max_height / shape.y()) : scale_y;
62 
63         std::uniform_int_distribution<uint8_t> distribution_u8(0, 255);
64         uint8_t                                constant_border_value = static_cast<uint8_t>(distribution_u8(generator));
65 
66         const int idx_width  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
67         const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
68 
69         TensorShape shape_scaled(shape);
70         shape_scaled.set(idx_width, shape[idx_width] * scale_x);
71         shape_scaled.set(idx_height, shape[idx_height] * scale_y);
72 
73         // Create tensors
74         src = create_tensor<TensorType>(shape, data_type);
75         dst = create_tensor<TensorType>(shape_scaled, data_type);
76 
77         // Create and configure function
78         scale_func.configure(&src, &dst, ScaleKernelInfo{ policy, border_mode, constant_border_value, sampling_policy, false });
79 
80         // Allocate tensors
81         src.allocator()->allocate();
82         dst.allocator()->allocate();
83     }
84 
run()85     void run()
86     {
87         scale_func.run();
88     }
89 
sync()90     void sync()
91     {
92         sync_if_necessary<TensorType>();
93         sync_tensor_if_necessary<TensorType>(dst);
94     }
95 
teardown()96     void teardown()
97     {
98         src.allocator()->free();
99         dst.allocator()->free();
100     }
101 
102 private:
103     TensorType src{};
104     TensorType dst{};
105     Function   scale_func{};
106 };
107 } // namespace benchmark
108 } // namespace test
109 } // namespace arm_compute
110 #endif /* ARM_COMPUTE_TEST_SCALE_FIXTURE */
111