1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #include <gtest/gtest.h>
10
11 #include "softmax-operator-tester.h"
12
13
TEST(SOFTMAX_NC_Q8, single_class) {
  // Smallest possible problem: a single row holding one class, checked over 100 iterations.
  SoftMaxOperatorTester tester{};
  tester.batch_size(1);
  tester.channels(1);
  tester.iterations(100);
  tester.TestQ8();
}
21
TEST(SOFTMAX_NC_Q8, two_classes) {
  // Binary case: a single row with two classes, checked over 100 iterations.
  SoftMaxOperatorTester tester{};
  tester.batch_size(1);
  tester.channels(2);
  tester.iterations(100);
  tester.TestQ8();
}
29
TEST(SOFTMAX_NC_Q8, many_classes) {
  // Exhaustive sweep of class counts 3..99 (1 and 2 are covered by their own tests),
  // one iteration per count.
  for (size_t num_classes = 3; num_classes <= 99; ++num_classes) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(1);
    tester.channels(num_classes);
    tester.iterations(1);
    tester.TestQ8();
  }
}
39
TEST(SOFTMAX_NC_Q8, cifar_classes) {
  // Class counts of CIFAR-10 and CIFAR-100, exercised in that order.
  const size_t kCifarClassCounts[] = {10, 100};
  for (const size_t class_count : kCifarClassCounts) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(1);
    tester.channels(class_count);
    tester.iterations(15);
    tester.TestQ8();
  }
}
54
TEST(SOFTMAX_NC_Q8, imagenet_classes) {
  // Class counts of ImageNet-1K, ImageNet-1K+1 and ImageNet-22K, exercised in that order.
  const size_t kImageNetClassCounts[] = {1000, 1001, 21841};
  for (const size_t class_count : kImageNetClassCounts) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(1);
    tester.channels(class_count);
    tester.iterations(10);
    tester.TestQ8();
  }
}
75
TEST(SOFTMAX_NC_Q8, many_channels_with_input_scale) {
  // Channel counts 1, 6, 11, ..., 96 crossed with a geometric sweep of the input scale:
  // start at 1e-2 and multiply by ~pi each step while the scale stays below 1e+2.
  const float kScaleRatio = 3.14159265f;
  for (size_t channels = 1; channels < 100; channels += 5) {
    float input_scale = 1.0e-2f;
    while (input_scale < 1.0e+2f) {
      SoftMaxOperatorTester tester{};
      tester.batch_size(1);
      tester.channels(channels);
      tester.input_scale(input_scale);
      tester.iterations(1);
      tester.TestQ8();
      input_scale *= kScaleRatio;
    }
  }
}
88
TEST(SOFTMAX_NC_Q8, many_channels_with_input_zero_point) {
  // Channel counts 1, 6, 11, ..., 96 crossed with input zero points
  // 0, 51, 102, 153, 204, 255 (both ends of the uint8 range included).
  for (size_t channels = 1; channels < 100; channels += 5) {
    for (int32_t step = 0; step <= 5; ++step) {
      const uint8_t zero_point = static_cast<uint8_t>(step * 51);
      SoftMaxOperatorTester tester{};
      tester.batch_size(1);
      tester.channels(channels);
      tester.input_zero_point(zero_point);
      tester.iterations(1);
      tester.TestQ8();
    }
  }
}
101
TEST(SOFTMAX_NC_Q8, small_batch) {
  // Three rows per run with contiguous (default) strides, channel counts 1, 6, ..., 96.
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(3);
    tester.channels(channels);
    tester.iterations(3);
    tester.TestQ8();
  }
}
111
TEST(SOFTMAX_NC_Q8, small_batch_with_input_stride) {
  // Input stride of 129 exceeds every channel count tested here (max 96).
  const size_t kInputStride = 129;
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(3);
    tester.channels(channels);
    tester.input_stride(kInputStride);
    tester.iterations(3);
    tester.TestQ8();
  }
}
122
TEST(SOFTMAX_NC_Q8, small_batch_with_output_stride) {
  // Output stride of 117 exceeds every channel count tested here (max 96).
  const size_t kOutputStride = 117;
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(3);
    tester.channels(channels);
    tester.output_stride(kOutputStride);
    tester.iterations(3);
    tester.TestQ8();
  }
}
133
TEST(SOFTMAX_NC_Q8, strided_batch_with_input_and_output_stride) {
  // Distinct input (129) and output (117) strides, both larger than the
  // biggest channel count tested here (96).
  const size_t kInputStride = 129;
  const size_t kOutputStride = 117;
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(3);
    tester.channels(channels);
    tester.input_stride(kInputStride);
    tester.output_stride(kOutputStride);
    tester.iterations(3);
    tester.TestQ8();
  }
}
145
TEST(SOFTMAX_NC_F32, single_class) {
  // Smallest possible problem: a single row holding one class, checked over 100 iterations.
  SoftMaxOperatorTester tester{};
  tester.batch_size(1);
  tester.channels(1);
  tester.iterations(100);
  tester.TestF32();
}
153
TEST(SOFTMAX_NC_F32, two_classes) {
  // Binary case: a single row with two classes, checked over 100 iterations.
  SoftMaxOperatorTester tester{};
  tester.batch_size(1);
  tester.channels(2);
  tester.iterations(100);
  tester.TestF32();
}
161
TEST(SOFTMAX_NC_F32, many_classes) {
  // Exhaustive sweep of class counts 3..99 (1 and 2 are covered by their own tests),
  // one iteration per count.
  for (size_t num_classes = 3; num_classes <= 99; ++num_classes) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(1);
    tester.channels(num_classes);
    tester.iterations(1);
    tester.TestF32();
  }
}
171
TEST(SOFTMAX_NC_F32, cifar_classes) {
  // Class counts of CIFAR-10 and CIFAR-100, exercised in that order.
  const size_t kCifarClassCounts[] = {10, 100};
  for (const size_t class_count : kCifarClassCounts) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(1);
    tester.channels(class_count);
    tester.iterations(15);
    tester.TestF32();
  }
}
186
TEST(SOFTMAX_NC_F32, imagenet_classes) {
  // Class counts of ImageNet-1K, ImageNet-1K+1 and ImageNet-22K, exercised in that order.
  const size_t kImageNetClassCounts[] = {1000, 1001, 21841};
  for (const size_t class_count : kImageNetClassCounts) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(1);
    tester.channels(class_count);
    tester.iterations(10);
    tester.TestF32();
  }
}
207
TEST(SOFTMAX_NC_F32, small_batch) {
  // Three rows per run with contiguous (default) strides, channel counts 1, 6, ..., 96.
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(3);
    tester.channels(channels);
    tester.iterations(3);
    tester.TestF32();
  }
}
217
TEST(SOFTMAX_NC_F32, small_batch_with_input_stride) {
  // Input stride of 129 exceeds every channel count tested here (max 96).
  const size_t kInputStride = 129;
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(3);
    tester.channels(channels);
    tester.input_stride(kInputStride);
    tester.iterations(3);
    tester.TestF32();
  }
}
228
TEST(SOFTMAX_NC_F32, small_batch_with_output_stride) {
  // Output stride of 117 exceeds every channel count tested here (max 96).
  const size_t kOutputStride = 117;
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(3);
    tester.channels(channels);
    tester.output_stride(kOutputStride);
    tester.iterations(3);
    tester.TestF32();
  }
}
239
TEST(SOFTMAX_NC_F32, strided_batch_with_input_and_output_stride) {
  // Distinct input (129) and output (117) strides, both larger than the
  // biggest channel count tested here (96).
  const size_t kInputStride = 129;
  const size_t kOutputStride = 117;
  for (size_t channels = 1; channels < 100; channels += 5) {
    SoftMaxOperatorTester tester{};
    tester.batch_size(3);
    tester.channels(channels);
    tester.input_stride(kInputStride);
    tester.output_stride(kOutputStride);
    tester.iterations(3);
    tester.TestF32();
  }
}
251