1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Tests of 2+D convolution with trivial kernels and no special variations (like
17 // strides and padding).
18
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
36
37 namespace xla {
38 namespace {
39
40 class ConvolutionTest : public ClientLibraryTestBase {
41 protected:
42 #if XLA_TEST_BACKEND_GPU
43 // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial
44 // convolution. So relax the absolute error threshold.
45 ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-3);
46 #else
47 ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-3);
48 #endif
49 };
50
// Element types that the typed convolution tests below are instantiated for.
// The preprocessor guards drop types that the backend under test cannot run.
using TestTypes = ::testing::Types<
// TODO(b/183565702): Support integer convs on GPU.
#if !XLA_TEST_BACKEND_GPU
    int32_t,
#endif
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
    Eigen::half,
#endif
    float>;
60
61 template <typename T>
62 class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
63 public:
RunTest()64 void RunTest() {
65 const int kInputActivationSizeY = 3;
66 const int kInputActivationSizeX = 3;
67 const int kInputActivationSizeZ = 256;
68 const int kKernelSizeX = 2;
69 const int kKernelSizeY = 2;
70 const int kOutputActivationSizeZ = 256;
71 const int kMiniBatchSize = 4;
72 auto alhs = std::make_unique<Array4D<T>>(
73 kMiniBatchSize, kInputActivationSizeZ, kInputActivationSizeY,
74 kInputActivationSizeX);
75 alhs->FillWithMultiples(static_cast<T>(static_cast<T>(1.0f)));
76 ASSERT_EQ(3, alhs->width());
77 ASSERT_EQ(3, alhs->height());
78
79 auto arhs = std::make_unique<Array4D<T>>(kOutputActivationSizeZ,
80 kInputActivationSizeZ,
81 kKernelSizeY, kKernelSizeX);
82 Array2D<T> rhs_raster({
83 {static_cast<T>(1.0f), static_cast<T>(0.0f)}, // row 0
84 {static_cast<T>(0.0f), static_cast<T>(0.0f)}, // row 1
85 });
86 arhs->FillWithYX(rhs_raster);
87 ASSERT_EQ(2, arhs->width());
88 ASSERT_EQ(2, arhs->height());
89
90 XlaBuilder builder(TestName());
91 auto lhs = ConstantR4FromArray4D<T>(&builder, *alhs);
92 auto rhs = ConstantR4FromArray4D<T>(&builder, *arhs);
93 PrecisionConfig precision;
94 // The left hand side of the convolution is numbers between 0 and 2304 which
95 // requires at least 11 mantissa bits and the DEFAULT precision config is
96 // allowed to round to bfloat16 which only has 7 mantissa bits.
97 precision.add_operand_precision(PrecisionConfig::HIGHEST);
98 precision.add_operand_precision(PrecisionConfig::DEFAULT);
99 Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1,
100 /*batch_group_count=*/1, &precision);
101
102 ComputeAndCompare(&builder, {}, error_spec_);
103 }
104 };
105
106 TYPED_TEST_CASE(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, TestTypes);
XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota,Types)107 XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, Types) {
108 this->RunTest();
109 }
110
111 template <typename T>
112 class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
113 public:
RunTest()114 void RunTest() {
115 XlaBuilder builder(TestName());
116 Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
117 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
118 auto input = Parameter(&builder, 0, input_shape, "input");
119 auto filter = Parameter(&builder, 1, filter_shape, "filter");
120 Conv(input, filter, {1, 1}, Padding::kValid);
121
122 Array4D<T> input_data(1, 1, 1, 2);
123 input_data.FillWithYX(Array2D<T>({
124 {static_cast<T>(1.0f), static_cast<T>(2.0f)},
125 }));
126 Array4D<T> filter_data(1, 1, 1, 2);
127 filter_data.FillWithYX(Array2D<T>({
128 {static_cast<T>(5.0f), static_cast<T>(6.0f)},
129 }));
130
131 ComputeAndCompare(&builder,
132 {LiteralUtil::CreateFromArray(input_data),
133 LiteralUtil::CreateFromArray(filter_data)},
134 error_spec_);
135 }
136 };
137
138 TYPED_TEST_CASE(Convolve_1x1x1x2_1x1x1x2_Valid, TestTypes);
TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid,Types)139 TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid, Types) { this->RunTest(); }
140
141 // Tests valid padding for 2D convolution in raster space.
142 template <typename T>
143 class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
144 public:
RunTest()145 void RunTest() {
146 XlaBuilder builder(TestName());
147 Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
148 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
149 auto input = Parameter(&builder, 0, input_shape, "input");
150 auto filter = Parameter(&builder, 1, filter_shape, "filter");
151 Conv(input, filter, {1, 1}, Padding::kValid);
152
153 Array4D<T> input_data(1, 1, 4, 4);
154 input_data.FillWithYX(Array2D<T>({
155 {static_cast<T>(1.0f), static_cast<T>(2.0f), static_cast<T>(3.0f),
156 static_cast<T>(4.0f)},
157 {static_cast<T>(5.0f), static_cast<T>(6.0f), static_cast<T>(7.0f),
158 static_cast<T>(8.0f)},
159 {static_cast<T>(9.0f), static_cast<T>(10.0f), static_cast<T>(11.0f),
160 static_cast<T>(12.0f)},
161 {static_cast<T>(13.0f), static_cast<T>(14.0f), static_cast<T>(15.0f),
162 static_cast<T>(16.0f)},
163 }));
164 Array4D<T> filter_data(1, 1, 2, 2);
165 filter_data.FillWithYX(Array2D<T>({
166 {static_cast<T>(5.0f), static_cast<T>(6.0f)},
167 {static_cast<T>(7.0f), static_cast<T>(8.0f)},
168 }));
169 ComputeAndCompare(&builder,
170 {LiteralUtil::CreateFromArray(input_data),
171 LiteralUtil::CreateFromArray(filter_data)},
172 error_spec_);
173 }
174 };
175
176 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Valid, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid,Types)177 TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid, Types) { this->RunTest(); }
178
179 // Tests same padding for 2D convolution in raster space.
180 template <typename T>
181 class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
182 public:
RunTest()183 void RunTest() {
184 XlaBuilder builder(TestName());
185 Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
186 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
187 auto input = Parameter(&builder, 0, input_shape, "input");
188 auto filter = Parameter(&builder, 1, filter_shape, "filter");
189 Conv(input, filter, {1, 1}, Padding::kSame);
190
191 Array4D<T> input_data(1, 1, 4, 4);
192 input_data.FillWithYX(Array2D<T>({
193 {static_cast<T>(1.0f), static_cast<T>(2.0f), static_cast<T>(3.0f),
194 static_cast<T>(4.0f)},
195 {static_cast<T>(5.0f), static_cast<T>(6.0f), static_cast<T>(7.0f),
196 static_cast<T>(8.0f)},
197 {static_cast<T>(9.0f), static_cast<T>(10.0f), static_cast<T>(11.0f),
198 static_cast<T>(12.0f)},
199 {static_cast<T>(13.0f), static_cast<T>(14.0f), static_cast<T>(15.0f),
200 static_cast<T>(16.0f)},
201 }));
202 Array4D<T> filter_data(1, 1, 2, 2);
203 filter_data.FillWithYX(Array2D<T>({
204 {static_cast<T>(5.0f), static_cast<T>(6.0f)},
205 {static_cast<T>(7.0f), static_cast<T>(8.0f)},
206 }));
207
208 ComputeAndCompare(&builder,
209 {LiteralUtil::CreateFromArray(input_data),
210 LiteralUtil::CreateFromArray(filter_data)},
211 error_spec_);
212 }
213 };
214
215 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Same, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same,Types)216 TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same, Types) { this->RunTest(); }
217
218 // Tests same padding for 2D convolution in raster space with an odd sized
219 // kernel.
220 template <typename T>
221 class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
222 public:
RunTest()223 void RunTest() {
224 XlaBuilder builder(TestName());
225 Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
226 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 3, 3});
227 auto input = Parameter(&builder, 0, input_shape, "input");
228 auto filter = Parameter(&builder, 1, filter_shape, "filter");
229 Conv(input, filter, {1, 1}, Padding::kSame);
230
231 Array4D<T> input_data(1, 1, 4, 4);
232 input_data.FillWithYX(
233 Array2D<T>({{static_cast<T>(1.0f), static_cast<T>(2.0f),
234 static_cast<T>(3.0f), static_cast<T>(4.0f)},
235 {static_cast<T>(5.0f), static_cast<T>(6.0f),
236 static_cast<T>(7.0f), static_cast<T>(8.0f)},
237 {static_cast<T>(9.0f), static_cast<T>(10.0f),
238 static_cast<T>(11.0f), static_cast<T>(12.0f)},
239 {static_cast<T>(13.0f), static_cast<T>(14.0f),
240 static_cast<T>(15.0f), static_cast<T>(16.0f)}}));
241 Array4D<T> filter_data(1, 1, 3, 3);
242 filter_data.FillWithYX(Array2D<T>(
243 {{static_cast<T>(5.0f), static_cast<T>(6.0f), static_cast<T>(7.0f)},
244 {static_cast<T>(8.0f), static_cast<T>(9.0f), static_cast<T>(10.0f)},
245 {static_cast<T>(11.0f), static_cast<T>(12.0f),
246 static_cast<T>(13.0f)}}));
247 // clang-format on
248 ComputeAndCompare(&builder,
249 {LiteralUtil::CreateFromArray(input_data),
250 LiteralUtil::CreateFromArray(filter_data)},
251 error_spec_);
252 }
253 };
254
255 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same,Types)256 TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); }
257
// 3D valid-padding F32 convolution of a {1, 4, 2, 3, 3} input with a
// {2, 2, 2, 3, 3} filter. Both operands are filled with 1, 2, 3, ... and the
// result is compared against a hard-coded {1, 3, 1, 2, 3} expectation.
XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
  XlaBuilder builder(TestName());
  std::vector<int64_t> input_dims = {1, 4, 2, 3, 3};
  std::vector<int64_t> filter_dims = {2, 2, 2, 3, 3};
  Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
  Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
  {
    auto input = Parameter(&builder, 0, input_shape, "input");
    auto filter = Parameter(&builder, 1, filter_shape, "filter");

    // Tensorflow dimension numbers for 3D convolution: batch is dimension 0,
    // spatial dimensions are 1-3 and the feature dimension is 4; the kernel
    // uses dimensions 0-2 as spatial, 3 as input feature, 4 as output feature.
    ConvolutionDimensionNumbers dnums;
    dnums.set_input_batch_dimension(0);
    dnums.set_output_batch_dimension(0);
    dnums.add_input_spatial_dimensions(1);
    dnums.add_output_spatial_dimensions(1);
    dnums.add_input_spatial_dimensions(2);
    dnums.add_output_spatial_dimensions(2);
    dnums.add_input_spatial_dimensions(3);
    dnums.add_output_spatial_dimensions(3);
    dnums.set_input_feature_dimension(4);
    dnums.set_output_feature_dimension(4);
    dnums.add_kernel_spatial_dimensions(0);
    dnums.add_kernel_spatial_dimensions(1);
    dnums.add_kernel_spatial_dimensions(2);
    dnums.set_kernel_input_feature_dimension(3);
    dnums.set_kernel_output_feature_dimension(4);

    ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid, dnums);
  }

  // std::iota is spelled with its namespace so the call does not depend on
  // argument-dependent lookup through the vector iterator type.
  std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
  std::iota(input_elems.begin(), input_elems.end(), 1.0f);
  auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
  auto input_r5 = input_r1.Reshape(input_dims).value();

  std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
  std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
  auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
  auto filter_r5 = filter_r1.Reshape(filter_dims).value();

  auto expected_r1 = LiteralUtil::CreateR1<float>(
      {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
       38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
  auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).value();

  auto input_literal = client_->TransferToServer(input_r5).value();
  auto filter_literal = client_->TransferToServer(filter_r5).value();

  ComputeAndCompareLiteral(&builder, expected_r5,
                           {input_literal.get(), filter_literal.get()},
                           error_spec_);
}
311
// Fills `values` with consecutive integers starting at `init_value`, each
// converted to T with static_cast. An empty vector is left untouched.
//
// This stands in for std::iota, which fails to build on some build servers
// when the initial value has type Eigen::half (the compiler reports a missing
// operator++).
template <typename T>
void iota_int_init_value(std::vector<T>& values, int init_value) {
  for (T& value : values) {
    value = static_cast<T>(init_value++);
  }
}
319
320 template <typename T>
321 class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
322 public:
RunTest()323 void RunTest() {
324 XlaBuilder builder(TestName());
325 std::vector<int64_t> input_dims = {1, 3, 3, 5};
326 std::vector<int64_t> filter_dims = {3, 3, 5, 3};
327 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
328 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
329 {
330 auto input = Parameter(&builder, 0, input_shape, "input");
331 auto filter = Parameter(&builder, 1, filter_shape, "filter");
332
333 // Tensorflow dimension numbers for 2D convolution.
334 ConvolutionDimensionNumbers dnums;
335 dnums.set_input_batch_dimension(0);
336 dnums.set_output_batch_dimension(0);
337 dnums.add_input_spatial_dimensions(1);
338 dnums.add_output_spatial_dimensions(1);
339 dnums.add_input_spatial_dimensions(2);
340 dnums.add_output_spatial_dimensions(2);
341 dnums.set_input_feature_dimension(3);
342 dnums.set_output_feature_dimension(3);
343 dnums.add_kernel_spatial_dimensions(0);
344 dnums.add_kernel_spatial_dimensions(1);
345 dnums.set_kernel_input_feature_dimension(2);
346 dnums.set_kernel_output_feature_dimension(3);
347
348 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums);
349 }
350
351 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
352 iota_int_init_value(input_elems, 1);
353 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
354 auto input_r4 = input_r1.Reshape(input_dims).value();
355
356 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
357 iota_int_init_value(filter_elems, 1);
358 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
359 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
360
361 auto expected_r1 = LiteralUtil::CreateR1<T>(
362 {static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
363 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).value();
364
365 auto input_literal = client_->TransferToServer(input_r4).value();
366 auto filter_literal = client_->TransferToServer(filter_r4).value();
367
368 ComputeAndCompareLiteral(&builder, expected_r4,
369 {input_literal.get(), filter_literal.get()},
370 error_spec_);
371 }
372 };
373
374 TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x3_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid,Types)375 TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid, Types) { this->RunTest(); }
376
377 template <typename T>
378 class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
379 public:
RunTest()380 void RunTest() {
381 XlaBuilder builder(TestName());
382 std::vector<int64_t> input_dims = {1, 3, 3, 5};
383 std::vector<int64_t> filter_dims = {3, 3, 1, 15};
384 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
385 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
386 {
387 auto input = Parameter(&builder, 0, input_shape, "input");
388 auto filter = Parameter(&builder, 1, filter_shape, "filter");
389
390 // Tensorflow dimension numbers for 2D convolution.
391 ConvolutionDimensionNumbers dnums;
392 dnums.set_input_batch_dimension(0);
393 dnums.set_output_batch_dimension(0);
394 dnums.add_input_spatial_dimensions(1);
395 dnums.add_output_spatial_dimensions(1);
396 dnums.add_input_spatial_dimensions(2);
397 dnums.add_output_spatial_dimensions(2);
398 dnums.set_input_feature_dimension(3);
399 dnums.set_output_feature_dimension(3);
400 dnums.add_kernel_spatial_dimensions(0);
401 dnums.add_kernel_spatial_dimensions(1);
402 dnums.set_kernel_input_feature_dimension(2);
403 dnums.set_kernel_output_feature_dimension(3);
404
405 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
406 /*feature_group_count=*/5);
407 }
408
409 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
410 iota_int_init_value(input_elems, 1);
411 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
412 auto input_r4 = input_r1.Reshape(input_dims).value();
413
414 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
415 iota_int_init_value(filter_elems, 1);
416 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
417 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
418
419 auto expected_r1 = LiteralUtil::CreateR1<T>(
420 {static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
421 static_cast<T>(17172), static_cast<T>(17370), static_cast<T>(17568),
422 static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
423 static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
424 static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
425 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).value();
426
427 auto input_literal = client_->TransferToServer(input_r4).value();
428 auto filter_literal = client_->TransferToServer(filter_r4).value();
429
430 ComputeAndCompareLiteral(&builder, expected_r4,
431 {input_literal.get(), filter_literal.get()},
432 error_spec_);
433 }
434 };
435
436 TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid,Types)437 TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) {
438 this->RunTest();
439 }
440
441 template <typename T>
442 class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid : public ConvolutionTest {
443 public:
RunTest()444 void RunTest() {
445 XlaBuilder builder(TestName());
446 std::vector<int64_t> input_dims = {1, 4, 4, 5};
447 std::vector<int64_t> filter_dims = {3, 3, 1, 5};
448 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
449 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
450 {
451 auto input = Parameter(&builder, 0, input_shape, "input");
452 auto filter = Parameter(&builder, 1, filter_shape, "filter");
453
454 // Tensorflow dimension numbers for 2D convolution.
455 ConvolutionDimensionNumbers dnums;
456 dnums.set_input_batch_dimension(0);
457 dnums.set_output_batch_dimension(0);
458 dnums.add_input_spatial_dimensions(1);
459 dnums.add_output_spatial_dimensions(1);
460 dnums.add_input_spatial_dimensions(2);
461 dnums.add_output_spatial_dimensions(2);
462 dnums.set_input_feature_dimension(3);
463 dnums.set_output_feature_dimension(3);
464 dnums.add_kernel_spatial_dimensions(0);
465 dnums.add_kernel_spatial_dimensions(1);
466 dnums.set_kernel_input_feature_dimension(2);
467 dnums.set_kernel_output_feature_dimension(3);
468
469 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
470 /*feature_group_count=*/5);
471 }
472
473 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
474 iota_int_init_value(input_elems, 1);
475 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
476 auto input_r4 = input_r1.Reshape(input_dims).value();
477
478 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
479 iota_int_init_value(filter_elems, 1);
480 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
481 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
482
483 auto expected_r1 = LiteralUtil::CreateR1<T>(
484 {static_cast<T>(6864), static_cast<T>(7296), static_cast<T>(7746),
485 static_cast<T>(8214), static_cast<T>(8700), static_cast<T>(7809),
486 static_cast<T>(8286), static_cast<T>(8781), static_cast<T>(9294),
487 static_cast<T>(9825), static_cast<T>(10644), static_cast<T>(11256),
488 static_cast<T>(11886), static_cast<T>(12534), static_cast<T>(13200),
489 static_cast<T>(11589), static_cast<T>(12246), static_cast<T>(12921),
490 static_cast<T>(13614), static_cast<T>(14325)});
491 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).value();
492
493 auto input_literal = client_->TransferToServer(input_r4).value();
494 auto filter_literal = client_->TransferToServer(filter_r4).value();
495
496 ComputeAndCompareLiteral(&builder, expected_r4,
497 {input_literal.get(), filter_literal.get()},
498 error_spec_);
499
500 auto filter_r = filter_r1.Reshape(filter_dims);
501 }
502 };
503
504 TYPED_TEST_CASE(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid,Types)505 TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, Types) {
506 this->RunTest();
507 }
508
509 template <typename T>
510 class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid : public ConvolutionTest {
511 public:
RunTest()512 void RunTest() {
513 XlaBuilder builder(TestName());
514 std::vector<int64_t> input_dims = {1, 4, 4, 512};
515 std::vector<int64_t> filter_dims = {3, 3, 1, 512};
516 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
517 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
518 {
519 auto input = Parameter(&builder, 0, input_shape, "input");
520 auto filter = Parameter(&builder, 1, filter_shape, "filter");
521
522 // Tensorflow dimension numbers for 2D convolution.
523 ConvolutionDimensionNumbers dnums;
524 dnums.set_input_batch_dimension(0);
525 dnums.set_output_batch_dimension(0);
526 dnums.add_input_spatial_dimensions(1);
527 dnums.add_output_spatial_dimensions(1);
528 dnums.add_input_spatial_dimensions(2);
529 dnums.add_output_spatial_dimensions(2);
530 dnums.set_input_feature_dimension(3);
531 dnums.set_output_feature_dimension(3);
532 dnums.add_kernel_spatial_dimensions(0);
533 dnums.add_kernel_spatial_dimensions(1);
534 dnums.set_kernel_input_feature_dimension(2);
535 dnums.set_kernel_output_feature_dimension(3);
536
537 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
538 /*feature_group_count=*/512);
539 }
540
541 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
542 static_cast<T>(1));
543 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
544 auto input_r4 = input_r1.Reshape(input_dims).value();
545
546 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
547 static_cast<T>(2));
548 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
549 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
550
551 std::vector<T> output_elems(2048, static_cast<T>(18));
552
553 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
554 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).value();
555
556 auto input_literal = client_->TransferToServer(input_r4).value();
557 auto filter_literal = client_->TransferToServer(filter_r4).value();
558
559 ComputeAndCompareLiteral(&builder, expected_r4,
560 {input_literal.get(), filter_literal.get()},
561 error_spec_);
562 }
563 };
564
565 TYPED_TEST_CASE(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid,Types)566 TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, Types) {
567 this->RunTest();
568 }
569
570 template <typename T>
571 class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes
572 : public ConvolutionTest {
573 public:
RunTest()574 void RunTest() {
575 XlaBuilder builder(TestName());
576 std::vector<int64_t> input_dims = {1, 4, 4, 512};
577 std::vector<int64_t> filter_dims = {3, 3, 1, 512};
578 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
579 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
580 {
581 auto input = Parameter(&builder, 0, input_shape, "input");
582 auto filter = Parameter(&builder, 1, filter_shape, "filter");
583
584 // Tensorflow dimension numbers for 2D convolution.
585 ConvolutionDimensionNumbers dnums;
586 dnums.set_input_batch_dimension(0);
587 dnums.set_output_batch_dimension(0);
588 dnums.add_input_spatial_dimensions(1);
589 dnums.add_output_spatial_dimensions(1);
590 dnums.add_input_spatial_dimensions(2);
591 dnums.add_output_spatial_dimensions(2);
592 dnums.set_input_feature_dimension(3);
593 dnums.set_output_feature_dimension(3);
594 dnums.add_kernel_spatial_dimensions(0);
595 dnums.add_kernel_spatial_dimensions(1);
596 dnums.set_kernel_input_feature_dimension(2);
597 dnums.set_kernel_output_feature_dimension(3);
598
599 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
600 /*feature_group_count=*/512);
601 }
602
603 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
604 static_cast<T>(1));
605 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
606 auto input_r4 = input_r1.Reshape(input_dims).value();
607
608 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
609 static_cast<T>(2));
610 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
611 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
612
613 std::vector<T> output_elems(2048, static_cast<T>(18));
614
615 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
616 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).value();
617 auto expected_r4_relaid =
618 expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
619
620 auto input_literal = client_->TransferToServer(input_r4).value();
621 auto filter_literal = client_->TransferToServer(filter_r4).value();
622
623 ComputeAndCompareLiteral(&builder, expected_r4_relaid,
624 {input_literal.get(), filter_literal.get()},
625 error_spec_, &expected_r4_relaid.shape());
626 }
627 };
628
629 TYPED_TEST_CASE(
630 Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes,
631 TestTypes);
TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes,Types)632 TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes,
633 Types) {
634 this->RunTest();
635 }
636
637 template <typename T>
638 class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes
639 : public ConvolutionTest {
640 public:
RunTest()641 void RunTest() {
642 XlaBuilder builder(TestName());
643 std::vector<int64_t> input_dims = {256, 4, 4, 512};
644 std::vector<int64_t> filter_dims = {3, 3, 1, 512};
645 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
646 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
647 {
648 auto input = Parameter(&builder, 0, input_shape, "input");
649 auto filter = Parameter(&builder, 1, filter_shape, "filter");
650
651 // Tensorflow dimension numbers for 2D convolution.
652 ConvolutionDimensionNumbers dnums;
653 dnums.set_input_batch_dimension(0);
654 dnums.set_output_batch_dimension(0);
655 dnums.add_input_spatial_dimensions(1);
656 dnums.add_output_spatial_dimensions(1);
657 dnums.add_input_spatial_dimensions(2);
658 dnums.add_output_spatial_dimensions(2);
659 dnums.set_input_feature_dimension(3);
660 dnums.set_output_feature_dimension(3);
661 dnums.add_kernel_spatial_dimensions(0);
662 dnums.add_kernel_spatial_dimensions(1);
663 dnums.set_kernel_input_feature_dimension(2);
664 dnums.set_kernel_output_feature_dimension(3);
665
666 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
667 /*feature_group_count=*/512);
668 }
669
670 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
671 static_cast<T>(1));
672 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
673 auto input_r4 = input_r1.Reshape(input_dims).value();
674 auto input_r4_relaid =
675 input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
676
677 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
678 static_cast<T>(2));
679 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
680 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
681
682 std::vector<T> output_elems(2048 * 256, static_cast<T>(18));
683
684 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
685 auto expected_r4 = expected_r1.Reshape({256, 2, 2, 512}).value();
686
687 auto input_literal = client_->TransferToServer(input_r4_relaid).value();
688 auto filter_literal = client_->TransferToServer(filter_r4).value();
689
690 ComputeAndCompareLiteral(&builder, expected_r4,
691 {input_literal.get(), filter_literal.get()},
692 error_spec_);
693 }
694 };
695
696 TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes,
697 TestTypes);
TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes,Types)698 TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes,
699 Types) {
700 this->RunTest();
701 }
702
703 template <typename T>
704 class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes
705 : public ConvolutionTest {
706 public:
RunTest()707 void RunTest() {
708 XlaBuilder builder(TestName());
709 std::vector<int64_t> input_dims = {256, 4, 4, 512};
710 std::vector<int64_t> filter_dims = {3, 3, 1, 512};
711 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
712 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
713 {
714 auto input = Parameter(&builder, 0, input_shape, "input");
715 auto filter = Parameter(&builder, 1, filter_shape, "filter");
716
717 // Tensorflow dimension numbers for 2D convolution.
718 ConvolutionDimensionNumbers dnums;
719 dnums.set_input_batch_dimension(0);
720 dnums.set_output_batch_dimension(0);
721 dnums.add_input_spatial_dimensions(1);
722 dnums.add_output_spatial_dimensions(1);
723 dnums.add_input_spatial_dimensions(2);
724 dnums.add_output_spatial_dimensions(2);
725 dnums.set_input_feature_dimension(3);
726 dnums.set_output_feature_dimension(3);
727 dnums.add_kernel_spatial_dimensions(0);
728 dnums.add_kernel_spatial_dimensions(1);
729 dnums.set_kernel_input_feature_dimension(2);
730 dnums.set_kernel_output_feature_dimension(3);
731
732 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
733 /*feature_group_count=*/512);
734 }
735
736 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
737 static_cast<T>(1));
738 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
739 auto input_r4 = input_r1.Reshape(input_dims).value();
740 auto input_r4_relaid =
741 input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
742
743 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
744 static_cast<T>(2));
745 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
746 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
747
748 std::vector<T> output_elems(2048 * 256, static_cast<T>(18));
749
750 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
751 auto expected_r4 = expected_r1.Reshape({256, 2, 2, 512}).value();
752 auto expected_r4_relaid =
753 expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
754
755 auto input_literal = client_->TransferToServer(input_r4_relaid).value();
756 auto filter_literal = client_->TransferToServer(filter_r4).value();
757
758 ComputeAndCompareLiteral(&builder, expected_r4_relaid,
759 {input_literal.get(), filter_literal.get()},
760 error_spec_, &expected_r4_relaid.shape());
761 }
762 };
763
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes,
                TestTypes);
TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes,
           Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
770
771 template <typename T>
772 class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes
773 : public ConvolutionTest {
774 public:
RunTest()775 void RunTest() {
776 XlaBuilder builder(TestName());
777 std::vector<int64_t> input_dims = {1, 4, 4, 5};
778 std::vector<int64_t> filter_dims = {3, 3, 1, 5};
779 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
780 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
781 {
782 auto input = Parameter(&builder, 0, input_shape, "input");
783 auto filter = Parameter(&builder, 1, filter_shape, "filter");
784
785 // Tensorflow dimension numbers for 2D convolution.
786 ConvolutionDimensionNumbers dnums;
787 dnums.set_input_batch_dimension(0);
788 dnums.set_output_batch_dimension(0);
789 dnums.add_input_spatial_dimensions(1);
790 dnums.add_output_spatial_dimensions(1);
791 dnums.add_input_spatial_dimensions(2);
792 dnums.add_output_spatial_dimensions(2);
793 dnums.set_input_feature_dimension(3);
794 dnums.set_output_feature_dimension(3);
795 dnums.add_kernel_spatial_dimensions(0);
796 dnums.add_kernel_spatial_dimensions(1);
797 dnums.set_kernel_input_feature_dimension(2);
798 dnums.set_kernel_output_feature_dimension(3);
799
800 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
801 /*feature_group_count=*/5);
802 }
803
804 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
805 iota_int_init_value(input_elems, 1);
806 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
807 auto input_r4 = input_r1.Reshape(input_dims).value();
808 auto input_r4_relaid =
809 input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
810
811 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
812 iota_int_init_value(filter_elems, 1);
813 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
814 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
815
816 auto expected_r1 = LiteralUtil::CreateR1<T>(
817 {static_cast<T>(6864), static_cast<T>(7296), static_cast<T>(7746),
818 static_cast<T>(8214), static_cast<T>(8700), static_cast<T>(7809),
819 static_cast<T>(8286), static_cast<T>(8781), static_cast<T>(9294),
820 static_cast<T>(9825), static_cast<T>(10644), static_cast<T>(11256),
821 static_cast<T>(11886), static_cast<T>(12534), static_cast<T>(13200),
822 static_cast<T>(11589), static_cast<T>(12246), static_cast<T>(12921),
823 static_cast<T>(13614), static_cast<T>(14325)});
824 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).value();
825 auto expected_r4_relaid =
826 expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
827
828 auto input_literal = client_->TransferToServer(input_r4_relaid).value();
829 auto filter_literal = client_->TransferToServer(filter_r4).value();
830
831 ComputeAndCompareLiteral(&builder, expected_r4_relaid,
832 {input_literal.get(), filter_literal.get()},
833 error_spec_, &expected_r4_relaid.shape());
834 }
835 };
836
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(
    Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes,
    TestTypes);
TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes,
           Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
844
845 template <typename T>
846 class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest {
847 public:
RunTest()848 void RunTest() {
849 XlaBuilder builder(TestName());
850 std::vector<int64_t> input_dims = {1, 4, 4, 160};
851 std::vector<int64_t> filter_dims = {3, 3, 1, 160};
852 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
853 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
854 {
855 auto input = Parameter(&builder, 0, input_shape, "input");
856 auto filter = Parameter(&builder, 1, filter_shape, "filter");
857
858 // Tensorflow dimension numbers for 2D convolution.
859 ConvolutionDimensionNumbers dnums;
860 dnums.set_input_batch_dimension(0);
861 dnums.set_output_batch_dimension(0);
862 dnums.add_input_spatial_dimensions(1);
863 dnums.add_output_spatial_dimensions(1);
864 dnums.add_input_spatial_dimensions(2);
865 dnums.add_output_spatial_dimensions(2);
866 dnums.set_input_feature_dimension(3);
867 dnums.set_output_feature_dimension(3);
868 dnums.add_kernel_spatial_dimensions(0);
869 dnums.add_kernel_spatial_dimensions(1);
870 dnums.set_kernel_input_feature_dimension(2);
871 dnums.set_kernel_output_feature_dimension(3);
872
873 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
874 /*feature_group_count=*/160);
875 }
876
877 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
878 static_cast<T>(1));
879 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
880 auto input_r4 = input_r1.Reshape(input_dims).value();
881
882 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
883 static_cast<T>(2));
884 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
885 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
886
887 std::vector<T> output_elems(640, static_cast<T>(18));
888
889 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
890 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).value();
891
892 auto input_literal = client_->TransferToServer(input_r4).value();
893 auto filter_literal = client_->TransferToServer(filter_r4).value();
894
895 ComputeAndCompareLiteral(&builder, expected_r4,
896 {input_literal.get(), filter_literal.get()},
897 error_spec_);
898 }
899 };
900
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
905
906 template <typename T>
907 class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes
908 : public ConvolutionTest {
909 public:
RunTest()910 void RunTest() {
911 XlaBuilder builder(TestName());
912 std::vector<int64_t> input_dims = {1, 4, 4, 160};
913 std::vector<int64_t> filter_dims = {3, 3, 1, 160};
914 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
915 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
916 {
917 auto input = Parameter(&builder, 0, input_shape, "input");
918 auto filter = Parameter(&builder, 1, filter_shape, "filter");
919
920 // Tensorflow dimension numbers for 2D convolution.
921 ConvolutionDimensionNumbers dnums;
922 dnums.set_input_batch_dimension(0);
923 dnums.set_output_batch_dimension(0);
924 dnums.add_input_spatial_dimensions(1);
925 dnums.add_output_spatial_dimensions(1);
926 dnums.add_input_spatial_dimensions(2);
927 dnums.add_output_spatial_dimensions(2);
928 dnums.set_input_feature_dimension(3);
929 dnums.set_output_feature_dimension(3);
930 dnums.add_kernel_spatial_dimensions(0);
931 dnums.add_kernel_spatial_dimensions(1);
932 dnums.set_kernel_input_feature_dimension(2);
933 dnums.set_kernel_output_feature_dimension(3);
934
935 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
936 /*feature_group_count=*/160);
937 }
938
939 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
940 static_cast<T>(1));
941 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
942 auto input_r4 = input_r1.Reshape(input_dims).value();
943 auto input_r4_relaid =
944 input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
945
946 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
947 static_cast<T>(2));
948 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
949 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
950
951 std::vector<T> output_elems(640, static_cast<T>(18));
952
953 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
954 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).value();
955 auto expected_r4_relaid =
956 expected_r4.Relayout(LayoutUtil::MakeLayout({3, 0, 2, 1}));
957
958 auto input_literal = client_->TransferToServer(input_r4_relaid).value();
959 auto filter_literal = client_->TransferToServer(filter_r4).value();
960
961 ComputeAndCompareLiteral(&builder, expected_r4_relaid,
962 {input_literal.get(), filter_literal.get()},
963 error_spec_, &expected_r4_relaid.shape());
964 }
965 };
966
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes,
                TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes,
           Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
973
974 template <typename T>
975 class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Both_Batch_In_Lanes
976 : public ConvolutionTest {
977 public:
RunTest()978 void RunTest() {
979 XlaBuilder builder(TestName());
980 std::vector<int64_t> input_dims = {1, 4, 4, 160};
981 std::vector<int64_t> filter_dims = {3, 3, 1, 160};
982 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
983 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
984 {
985 auto input = Parameter(&builder, 0, input_shape, "input");
986 auto filter = Parameter(&builder, 1, filter_shape, "filter");
987
988 // Tensorflow dimension numbers for 2D convolution.
989 ConvolutionDimensionNumbers dnums;
990 dnums.set_input_batch_dimension(0);
991 dnums.set_output_batch_dimension(0);
992 dnums.add_input_spatial_dimensions(1);
993 dnums.add_output_spatial_dimensions(1);
994 dnums.add_input_spatial_dimensions(2);
995 dnums.add_output_spatial_dimensions(2);
996 dnums.set_input_feature_dimension(3);
997 dnums.set_output_feature_dimension(3);
998 dnums.add_kernel_spatial_dimensions(0);
999 dnums.add_kernel_spatial_dimensions(1);
1000 dnums.set_kernel_input_feature_dimension(2);
1001 dnums.set_kernel_output_feature_dimension(3);
1002
1003 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1004 /*feature_group_count=*/160);
1005 }
1006
1007 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1008 static_cast<T>(1));
1009 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1010 auto input_r4 = input_r1.Reshape(input_dims).value();
1011 auto input_r4_relaid =
1012 input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
1013
1014 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1015 static_cast<T>(2));
1016 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1017 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1018
1019 std::vector<T> output_elems(640, static_cast<T>(18));
1020
1021 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1022 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).value();
1023 auto expected_r4_relaid =
1024 expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
1025
1026 auto input_literal = client_->TransferToServer(input_r4_relaid).value();
1027 auto filter_literal = client_->TransferToServer(filter_r4).value();
1028
1029 ComputeAndCompareLiteral(&builder, expected_r4_relaid,
1030 {input_literal.get(), filter_literal.get()},
1031 error_spec_, &expected_r4_relaid.shape());
1032 }
1033 };
1034
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Both_Batch_In_Lanes,
                TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Both_Batch_In_Lanes,
           Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
1041
1042 template <typename T>
1043 class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid
1044 : public ConvolutionTest {
1045 public:
RunTest()1046 void RunTest() {
1047 XlaBuilder builder(TestName());
1048 std::vector<int64_t> input_dims = {1, 4, 4, 1024};
1049 std::vector<int64_t> filter_dims = {3, 3, 1, 1024};
1050 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1051 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1052 {
1053 auto input = Parameter(&builder, 0, input_shape, "input");
1054 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1055
1056 // Tensorflow dimension numbers for 2D convolution.
1057 ConvolutionDimensionNumbers dnums;
1058 dnums.set_input_batch_dimension(0);
1059 dnums.set_output_batch_dimension(0);
1060 dnums.add_input_spatial_dimensions(1);
1061 dnums.add_output_spatial_dimensions(1);
1062 dnums.add_input_spatial_dimensions(2);
1063 dnums.add_output_spatial_dimensions(2);
1064 dnums.set_input_feature_dimension(3);
1065 dnums.set_output_feature_dimension(3);
1066 dnums.add_kernel_spatial_dimensions(0);
1067 dnums.add_kernel_spatial_dimensions(1);
1068 dnums.set_kernel_input_feature_dimension(2);
1069 dnums.set_kernel_output_feature_dimension(3);
1070
1071 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1072 /*feature_group_count=*/1024);
1073 }
1074
1075 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1076 static_cast<T>(1));
1077 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1078 auto input_r4 = input_r1.Reshape(input_dims).value();
1079
1080 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1081 static_cast<T>(2));
1082 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1083 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1084
1085 std::vector<T> output_elems(4096, static_cast<T>(18));
1086
1087 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1088 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 1024}).value();
1089
1090 auto input_literal = client_->TransferToServer(input_r4).value();
1091 auto filter_literal = client_->TransferToServer(filter_r4).value();
1092
1093 ComputeAndCompareLiteral(&builder, expected_r4,
1094 {input_literal.get(), filter_literal.get()},
1095 error_spec_);
1096 }
1097 };
1098
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
1103
1104 template <typename T>
1105 class Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid : public ConvolutionTest {
1106 public:
RunTest()1107 void RunTest() {
1108 XlaBuilder builder(TestName());
1109 std::vector<int64_t> input_dims = {1, 2, 2, 6};
1110 std::vector<int64_t> filter_dims = {2, 2, 2, 12};
1111 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1112 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1113 {
1114 auto input = Parameter(&builder, 0, input_shape, "input");
1115 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1116
1117 // Tensorflow dimension numbers for 2D convolution.
1118 ConvolutionDimensionNumbers dnums;
1119 dnums.set_input_batch_dimension(0);
1120 dnums.set_output_batch_dimension(0);
1121 dnums.add_input_spatial_dimensions(1);
1122 dnums.add_output_spatial_dimensions(1);
1123 dnums.add_input_spatial_dimensions(2);
1124 dnums.add_output_spatial_dimensions(2);
1125 dnums.set_input_feature_dimension(3);
1126 dnums.set_output_feature_dimension(3);
1127 dnums.add_kernel_spatial_dimensions(0);
1128 dnums.add_kernel_spatial_dimensions(1);
1129 dnums.set_kernel_input_feature_dimension(2);
1130 dnums.set_kernel_output_feature_dimension(3);
1131
1132 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1133 /*feature_group_count=*/3);
1134 }
1135
1136 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1137 iota_int_init_value(input_elems, 1);
1138 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1139 auto input_r4 = input_r1.Reshape(input_dims).value();
1140
1141 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1142 iota_int_init_value(filter_elems, 1);
1143 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1144 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1145
1146 auto expected_r1 = LiteralUtil::CreateR1<T>(
1147 {static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
1148 static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
1149 static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
1150 static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
1151 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).value();
1152
1153 auto input_literal = client_->TransferToServer(input_r4).value();
1154 auto filter_literal = client_->TransferToServer(filter_r4).value();
1155
1156 ComputeAndCompareLiteral(&builder, expected_r4,
1157 {input_literal.get(), filter_literal.get()},
1158 error_spec_);
1159 }
1160 };
1161
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
1166
1167 template <typename T>
1168 class Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid : public ConvolutionTest {
1169 public:
RunTest()1170 void RunTest() {
1171 XlaBuilder builder(TestName());
1172 std::vector<int64_t> input_dims = {1, 2, 2, 1024};
1173 std::vector<int64_t> filter_dims = {2, 2, 128, 512};
1174 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1175 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1176 {
1177 auto input = Parameter(&builder, 0, input_shape, "input");
1178 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1179
1180 // Tensorflow dimension numbers for 2D convolution.
1181 ConvolutionDimensionNumbers dnums;
1182 dnums.set_input_batch_dimension(0);
1183 dnums.set_output_batch_dimension(0);
1184 dnums.add_input_spatial_dimensions(1);
1185 dnums.add_output_spatial_dimensions(1);
1186 dnums.add_input_spatial_dimensions(2);
1187 dnums.add_output_spatial_dimensions(2);
1188 dnums.set_input_feature_dimension(3);
1189 dnums.set_output_feature_dimension(3);
1190 dnums.add_kernel_spatial_dimensions(0);
1191 dnums.add_kernel_spatial_dimensions(1);
1192 dnums.set_kernel_input_feature_dimension(2);
1193 dnums.set_kernel_output_feature_dimension(3);
1194
1195 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1196 /*feature_group_count=*/8);
1197 }
1198
1199 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1200 static_cast<T>(1));
1201
1202 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1203 auto input_r4 = input_r1.Reshape(input_dims).value();
1204
1205 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1206 static_cast<T>(2));
1207
1208 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1209 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1210
1211 std::vector<T> output_elems(512, static_cast<T>(1024));
1212 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1213 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 512}).value();
1214
1215 auto input_literal = client_->TransferToServer(input_r4).value();
1216 auto filter_literal = client_->TransferToServer(filter_r4).value();
1217
1218 ComputeAndCompareLiteral(&builder, expected_r4,
1219 {input_literal.get(), filter_literal.get()},
1220 error_spec_);
1221 }
1222 };
1223
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
1228
1229 template <typename T>
1230 class Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid : public ConvolutionTest {
1231 public:
RunTest()1232 void RunTest() {
1233 XlaBuilder builder(TestName());
1234 std::vector<int64_t> input_dims = {1, 2, 2, 1024};
1235 std::vector<int64_t> filter_dims = {2, 2, 128, 8};
1236 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1237 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1238 {
1239 auto input = Parameter(&builder, 0, input_shape, "input");
1240 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1241
1242 // Tensorflow dimension numbers for 2D convolution.
1243 ConvolutionDimensionNumbers dnums;
1244 dnums.set_input_batch_dimension(0);
1245 dnums.set_output_batch_dimension(0);
1246 dnums.add_input_spatial_dimensions(1);
1247 dnums.add_output_spatial_dimensions(1);
1248 dnums.add_input_spatial_dimensions(2);
1249 dnums.add_output_spatial_dimensions(2);
1250 dnums.set_input_feature_dimension(3);
1251 dnums.set_output_feature_dimension(3);
1252 dnums.add_kernel_spatial_dimensions(0);
1253 dnums.add_kernel_spatial_dimensions(1);
1254 dnums.set_kernel_input_feature_dimension(2);
1255 dnums.set_kernel_output_feature_dimension(3);
1256
1257 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1258 /*feature_group_count=*/8);
1259 }
1260
1261 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1262 static_cast<T>(1));
1263
1264 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1265 auto input_r4 = input_r1.Reshape(input_dims).value();
1266
1267 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1268 static_cast<T>(2));
1269
1270 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1271 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1272
1273 std::vector<T> output_elems(8, static_cast<T>(1024));
1274 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1275 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 8}).value();
1276
1277 auto input_literal = client_->TransferToServer(input_r4).value();
1278 auto filter_literal = client_->TransferToServer(filter_r4).value();
1279
1280 ComputeAndCompareLiteral(&builder, expected_r4,
1281 {input_literal.get(), filter_literal.get()},
1282 error_spec_);
1283 }
1284 };
1285
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
1290
1291 template <typename T>
1292 class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid : public ConvolutionTest {
1293 public:
RunTest()1294 void RunTest() {
1295 XlaBuilder builder(TestName());
1296 std::vector<int64_t> input_dims = {1, 2, 2, 12};
1297 std::vector<int64_t> filter_dims = {2, 2, 3, 4};
1298 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1299 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1300 {
1301 auto input = Parameter(&builder, 0, input_shape, "input");
1302 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1303
1304 // Tensorflow dimension numbers for 2D convolution.
1305 ConvolutionDimensionNumbers dnums;
1306 dnums.set_input_batch_dimension(0);
1307 dnums.set_output_batch_dimension(0);
1308 dnums.add_input_spatial_dimensions(1);
1309 dnums.add_output_spatial_dimensions(1);
1310 dnums.add_input_spatial_dimensions(2);
1311 dnums.add_output_spatial_dimensions(2);
1312 dnums.set_input_feature_dimension(3);
1313 dnums.set_output_feature_dimension(3);
1314 dnums.add_kernel_spatial_dimensions(0);
1315 dnums.add_kernel_spatial_dimensions(1);
1316 dnums.set_kernel_input_feature_dimension(2);
1317 dnums.set_kernel_output_feature_dimension(3);
1318
1319 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1320 /*feature_group_count=*/4);
1321 }
1322
1323 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1324 iota_int_init_value(input_elems, 1);
1325 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1326 auto input_r4 = input_r1.Reshape(input_dims).value();
1327
1328 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1329 iota_int_init_value(filter_elems, 1);
1330 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1331 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1332
1333 auto expected_r1 =
1334 LiteralUtil::CreateR1<T>({static_cast<T>(7712), static_cast<T>(8816),
1335 static_cast<T>(9992), static_cast<T>(11240)});
1336 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).value();
1337
1338 auto input_literal = client_->TransferToServer(input_r4).value();
1339 auto filter_literal = client_->TransferToServer(filter_r4).value();
1340
1341 ComputeAndCompareLiteral(&builder, expected_r4,
1342 {input_literal.get(), filter_literal.get()},
1343 error_spec_);
1344 }
1345 };
1346
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
1351
1352 template <typename T>
1353 class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes
1354 : public ConvolutionTest {
1355 public:
RunTest()1356 void RunTest() {
1357 XlaBuilder builder(TestName());
1358 std::vector<int64_t> input_dims = {1, 2, 2, 12};
1359 std::vector<int64_t> filter_dims = {2, 2, 4, 3};
1360 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1361 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1362 {
1363 auto input = Parameter(&builder, 0, input_shape, "input");
1364 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1365
1366 // Tensorflow dimension numbers for 2D convolution.
1367 ConvolutionDimensionNumbers dnums;
1368 dnums.set_input_batch_dimension(0);
1369 dnums.set_output_batch_dimension(0);
1370 dnums.add_input_spatial_dimensions(1);
1371 dnums.add_output_spatial_dimensions(1);
1372 dnums.add_input_spatial_dimensions(2);
1373 dnums.add_output_spatial_dimensions(2);
1374 dnums.set_input_feature_dimension(3);
1375 dnums.set_output_feature_dimension(3);
1376 dnums.add_kernel_spatial_dimensions(0);
1377 dnums.add_kernel_spatial_dimensions(1);
1378 dnums.set_kernel_input_feature_dimension(3);
1379 dnums.set_kernel_output_feature_dimension(2);
1380
1381 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1382 /*feature_group_count=*/4);
1383 }
1384
1385 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1386 iota_int_init_value(input_elems, 1);
1387 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1388 auto input_r4 = input_r1.Reshape(input_dims).value();
1389
1390 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1391 iota_int_init_value(filter_elems, 1);
1392 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1393 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1394 auto filter_r4_relaid =
1395 filter_r4.Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
1396 auto expected_r1 = LiteralUtil::CreateR1<T>(
1397 {static_cast<T>(6968), static_cast<T>(8516), static_cast<T>(10280),
1398 static_cast<T>(12260)});
1399 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).value();
1400
1401 auto input_literal = client_->TransferToServer(input_r4).value();
1402 auto filter_literal = client_->TransferToServer(filter_r4_relaid).value();
1403
1404 ComputeAndCompareLiteral(&builder, expected_r4,
1405 {input_literal.get(), filter_literal.get()},
1406 error_spec_);
1407 }
1408 };
1409
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes,
                TestTypes);
TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes,
           Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
1416
1417 template <typename T>
1418 class Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid : public ConvolutionTest {
1419 public:
RunTest()1420 void RunTest() {
1421 XlaBuilder builder(TestName());
1422 std::vector<int64_t> input_dims = {1, 1, 1, 12};
1423 std::vector<int64_t> filter_dims = {1, 1, 3, 4};
1424 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1425 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1426 {
1427 auto input = Parameter(&builder, 0, input_shape, "input");
1428 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1429
1430 // Tensorflow dimension numbers for 2D convolution.
1431 ConvolutionDimensionNumbers dnums;
1432 dnums.set_input_batch_dimension(0);
1433 dnums.set_output_batch_dimension(0);
1434 dnums.add_input_spatial_dimensions(1);
1435 dnums.add_output_spatial_dimensions(1);
1436 dnums.add_input_spatial_dimensions(2);
1437 dnums.add_output_spatial_dimensions(2);
1438 dnums.set_input_feature_dimension(3);
1439 dnums.set_output_feature_dimension(3);
1440 dnums.add_kernel_spatial_dimensions(0);
1441 dnums.add_kernel_spatial_dimensions(1);
1442 dnums.set_kernel_input_feature_dimension(2);
1443 dnums.set_kernel_output_feature_dimension(3);
1444
1445 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1446 /*feature_group_count=*/4);
1447 }
1448
1449 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1450 iota_int_init_value(input_elems, 1);
1451 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1452 auto input_r4 = input_r1.Reshape(input_dims).value();
1453
1454 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1455 iota_int_init_value(filter_elems, 1);
1456 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1457 auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1458
1459 auto expected_r1 =
1460 LiteralUtil::CreateR1<T>({static_cast<T>(38), static_cast<T>(98),
1461 static_cast<T>(176), static_cast<T>(272)});
1462 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).value();
1463
1464 auto input_literal = client_->TransferToServer(input_r4).value();
1465 auto filter_literal = client_->TransferToServer(filter_r4).value();
1466
1467 ComputeAndCompareLiteral(&builder, expected_r4,
1468 {input_literal.get(), filter_literal.get()},
1469 error_spec_);
1470 }
1471 };
1472
// Registers the fixture for every element type in TestTypes; the single test
// body just delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, Types) {
  // `this->` is required because the fixture is a dependent base class.
  this->RunTest();
}
1477
1478 // Test fixture to run convolution tests with and without convolution
1479 // canonicalization enabled.
1480 class ConvolveWithAndWithoutCanonicalization
1481 : public ConvolutionTest,
1482 public ::testing::WithParamInterface<bool> {};
1483
// Convolution with no spatial dimensions. Per the dimension numbers below, the
// input is [feature=4, batch=29], the kernel is [input feature=4, output
// feature=10] and the output is [batch, feature]. When the test parameter is
// true, the "convolution-canonicalization" HLO pass is disabled for the run.
XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
           DISABLED_ON_GPU(Convolve2D_NoSpatialDims)) {
  if (GetParam()) {
    execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
        "convolution-canonicalization");
  }
  XlaBuilder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29});
  const Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10});

  auto input = Parameter(&builder, 0, input_shape, "input");
  auto filter = Parameter(&builder, 1, filter_shape, "filter");

  // No spatial dimensions are added, so the window strides list is empty too.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_feature_dimension(0);
  dnums.set_input_batch_dimension(1);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums);

  Array2D<float> param0(4, 29);
  param0.FillUnique();

  Array2D<float> param1(4, 10);
  param1.FillUnique();

  // No hand-written expected literal is passed: this ComputeAndCompare call
  // receives only the argument literals and the error spec.
  ComputeAndCompare(&builder,
                    {LiteralUtil::CreateFromArray(param0),
                     LiteralUtil::CreateFromArray(param1)},
                    error_spec_);
}
1520
1521 INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation,
1522 ConvolveWithAndWithoutCanonicalization,
1523 ::testing::Values(true, false));
1524
// BF16 convolution of a 1x1x1x2 input {1, 2} with a 1x1x1x2 kernel {5, 6},
// using Conv's default dimension numbers, unit strides and VALID padding.
XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
  // Operand values: a single 1x2 plane for each 4D array.
  const Array2D<bfloat16> input_plane({{bfloat16(1), bfloat16(2)}});
  Array4D<bfloat16> input_values(1, 1, 1, 2);
  input_values.FillWithYX(input_plane);

  const Array2D<bfloat16> filter_plane({{bfloat16(5), bfloat16(6)}});
  Array4D<bfloat16> filter_values(1, 1, 1, 2);
  filter_values.FillWithYX(filter_plane);

  // Computation under test.
  XlaBuilder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
  const Shape kernel_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
  const auto lhs = Parameter(&builder, 0, operand_shape, "input");
  const auto rhs = Parameter(&builder, 1, kernel_shape, "filter");
  Conv(lhs, rhs, {1, 1}, Padding::kValid);

  ComputeAndCompare(&builder,
                    {LiteralUtil::CreateFromArray(input_values),
                     LiteralUtil::CreateFromArray(filter_values)},
                    error_spec_);
}
1547
1548 // Check that GPU convs still work if the CudnnAlgorithmPicker pass is disabled.
1549 // (We run this test on all platforms, because, what the heck.)
XLA_TEST_F(ConvolutionTest,DISABLED_ON_GPU_ROCM (NoCudnnAlgorithmPicker))1550 XLA_TEST_F(ConvolutionTest, DISABLED_ON_GPU_ROCM(NoCudnnAlgorithmPicker)) {
1551 execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
1552 "gpu-conv-algorithm-picker");
1553
1554 XlaBuilder builder(TestName());
1555 Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1556 Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1557 auto input = Parameter(&builder, 0, input_shape, "input");
1558 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1559 Conv(input, filter, {1, 1}, Padding::kValid);
1560
1561 Array4D<float> input_data(1, 1, 1, 2);
1562 input_data.FillIota(0);
1563 Array4D<float> filter_data(1, 1, 1, 2);
1564 filter_data.FillIota(10);
1565
1566 ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
1567 LiteralUtil::CreateFromArray(filter_data)});
1568 }
1569
// Runs a grouped F32 convolution of a {1, 64, 100, 100} parameter with a
// constant {7, 7, 1, 64} filter. feature_group_count is 64, which equals the
// input feature count, so each group sees the single kernel input feature.
// Strides are 1x1 and each spatial dimension is padded by 3 on both sides.
XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
  XlaBuilder builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100});
  Array4D<float> input_data(1, 64, 100, 100);
  input_data.FillRandom(/*stddev=*/0.023, 0.001, /*seed=*/45321);
  Shape filter_shape = ShapeUtil::MakeShape(F32, {7, 7, 1, 64});
  Array4D<float> filter_data(7, 7, 1, 64);
  filter_data.FillRandom(/*stddev=*/0.023, 0.001, /*seed=*/45320);
  // The input is a runtime parameter; the filter is baked in as a constant.
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto filter = ConstantR4FromArray4D(&builder, filter_data);

  // Specify bf01_01io->bf01 as dimension numbers.
  ConvolutionDimensionNumbers dnums;
  // Input
  dnums.set_input_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  // Kernel
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);
  // Output
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);
  ConvGeneral(input, filter, /*window_strides=*/{1, 1},
              /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
              /*feature_group_count=*/64);

  ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
                    error_spec_);
}
1605
1606 class ConvolutionHloTest : public HloTestBase {};
1607
1608 // double datatype is not yet supported in ROCm
XLA_TEST_F(ConvolutionHloTest,DISABLED_ON_GPU_ROCM (ConvolveF64Forward))1609 XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU_ROCM(ConvolveF64Forward)) {
1610 constexpr char kHlo[] = R"(
1611 HloModule TestModule
1612
1613 ENTRY Test {
1614 %arg0 = f64[3,56,56,16] parameter(0)
1615 %arg1 = f64[3,3,3,64] parameter(1)
1616 ROOT %conv = f64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
1617 })";
1618 EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
1619 }
1620
// Runs a c64 (complex) convolution with a 3x3 window and dim_labels
// f01b_i01o->01bf through RunAndCompare with ErrorSpec{0.01, 0.01}.
XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU(ConvolveC64Forward)) {
  constexpr char kHlo[] = R"(
HloModule TestModule

ENTRY Test {
  %arg0 = c64[3,56,56,16] parameter(0)
  %arg1 = c64[3,3,3,64] parameter(1)
  ROOT %conv = c64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
})";
  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
1632
// Runs an f32 convolution whose 3x3 window sets rhs_reversal=1x1 on both
// window dimensions, through RunAndCompare with ErrorSpec{0.001}.
XLA_TEST_F(ConvolutionHloTest,
           DISABLED_ON_GPU_ROCM(ConvolveF32ForwardReversed)) {
  constexpr char kHlo[] = R"(
HloModule TestModule

ENTRY Test {
  %arg0 = f32[3,56,56,16] parameter(0)
  %arg1 = f32[3,3,3,32] parameter(1)
  ROOT %conv = f32[54,54,16,32] convolution(%arg0, %arg1), window={size=3x3 rhs_reversal=1x1}, dim_labels=f01b_i01o->01bf
})";
  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
}
1645
1646 // double datatype is not yet supported in ROCm
XLA_TEST_F(ConvolutionHloTest,DISABLED_ON_GPU_ROCM (ConvolveF64BackwardFilter))1647 XLA_TEST_F(ConvolutionHloTest,
1648 DISABLED_ON_GPU_ROCM(ConvolveF64BackwardFilter)) {
1649 constexpr char kHlo[] = R"(
1650 HloModule TestModule
1651
1652 ENTRY Test {
1653 %arg0 = f64[2,5,8,1] parameter(0)
1654 %arg1 = f64[2,5,8,2] parameter(1)
1655 ROOT %conv = f64[4,4,1,2] convolution(%arg0, %arg1), window={size=5x8 pad=1_2x1_2}, dim_labels=f01b_i01o->01bf
1656 })";
1657 EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
1658 }
1659
1660 // double datatype is not yet supported in ROCm
XLA_TEST_F(ConvolutionHloTest,DISABLED_ON_GPU_ROCM (ConvolveF64BackwardInput))1661 XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU_ROCM(ConvolveF64BackwardInput)) {
1662 constexpr char kHlo[] = R"(
1663 HloModule TestModule
1664
1665 ENTRY Test {
1666 %output = f64[4,5,16,16] parameter(0)
1667 %kernel = f64[5,3,7,7] parameter(1)
1668 %reverse = f64[5,3,7,7] reverse(f64[5,3,7,7] %kernel), dimensions={2,3}
1669 ROOT %convolution = f64[4,3,16,16] convolution(%output, %reverse), window={size=7x7 pad=3_3x3_3}, dim_labels=bf01_io01->bf01
1670 })";
1671 EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
1672 }
1673
// Reverses an f32 kernel along dimensions {1,2} (its spatial dimensions under
// o01i) and convolves it with the first parameter using a 7x7 window padded by
// 6 on every side. Checked through RunAndCompare with ErrorSpec{0.01, 0.01}.
XLA_TEST_F(ConvolutionHloTest, ConvolveBackwardInput) {
  constexpr char kHlo[] = R"(
HloModule TestModule

ENTRY Test {
  %output = f32[3,3,64,64] parameter(0)
  %kernel = f32[672,7,7,64] parameter(1)
  %reverse = f32[672,7,7,64]{3,2,1,0} reverse(f32[672,7,7,64]{3,2,1,0} %kernel), dimensions={1,2}
  ROOT %convolution = f32[672,9,9,64]{3,2,1,0} convolution(f32[3,3,64,64]{3,2,1,0} %output, f32[672,7,7,64]{3,2,1,0} %reverse), window={size=7x7 pad=6_6x6_6}, dim_labels=01bf_o01i->f01b
})";
  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
1686
// Runs an f32 convolution in which the 11x11 window (the rhs spatial extent)
// is larger than the 3x3 spatial extent of the lhs, relying on asymmetric
// padding (3_25 and 3_6). Checked through RunAndCompare with
// ErrorSpec{0.01, 0.01}.
XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolve) {
  constexpr char kHlo[] = R"(
HloModule TestModule

ENTRY Test {
  %lhs = f32[3,3,7,7] parameter(0)
  %rhs = f32[5,11,11,7] parameter(1)
  ROOT %convolution = f32[5,21,2,7] convolution(lhs, rhs),
     window={size=11x11 pad=3_25x3_6},
     dim_labels=01bf_o01i->f01b
})";
  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
1700
// Same operand shapes as SwappedOperandConvolve (11x11 window over a 3x3 lhs
// spatial extent), but with stride=2x1 and padding 3_26x3_6. Checked through
// RunAndCompare with ErrorSpec{0.01, 0.01}.
XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolveWithStride) {
  constexpr char kHlo[] = R"(
HloModule TestModule

ENTRY Test {
  %lhs = f32[3,3,7,7] parameter(0)
  %rhs = f32[5,11,11,7] parameter(1)
  ROOT %convolution = f32[5,11,2,7] convolution(lhs, rhs),
     window={size=11x11 pad=3_26x3_6 stride=2x1},
     dim_labels=01bf_o01i->f01b
})";
  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
// Same operand shapes as SwappedOperandConvolve (11x11 window over a 3x3 lhs
// spatial extent), but with lhs_dilate=1x2 and rhs_dilate=2x1 in addition to
// the 3_25x3_6 padding. Checked through RunAndCompare with
// ErrorSpec{0.01, 0.01}.
XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolve2) {
  constexpr char kHlo[] = R"(
HloModule TestModule

ENTRY Test {
  %lhs = f32[3,3,7,7] parameter(0)
  %rhs = f32[5,11,11,7] parameter(1)
  ROOT %convolution = f32[5,11,4,7] convolution(lhs, rhs),
     window={size=11x11 pad=3_25x3_6 lhs_dilate=1x2 rhs_dilate=2x1},
     dim_labels=01bf_o01i->f01b
})";
  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
1727
// Runs an f32 convolution with no spatial dimensions and no window attribute
// (dim_labels bf_io->bf): f32[10,5] with f32[5,7] yields f32[10,7]. Checked
// through RunAndCompare with ErrorSpec{0.01, 0.01}.
XLA_TEST_F(ConvolutionHloTest, TestConv0D) {
  constexpr char kHlo[] = R"(
HloModule TestModule

ENTRY TestComputation {
  %parameter.1 = f32[10,5]{1,0} parameter(0)
  %parameter.2 = f32[5,7]{1,0} parameter(1)
  ROOT %convolution.3 = f32[10,7]{1,0} convolution(f32[10,5]{1,0} %parameter.1, f32[5,7]{1,0} %parameter.2), dim_labels=bf_io->bf
})";
  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
1739
// Runs a 2D f32 convolution (3x3 window, padding 1 on every side) followed by
// a broadcast bias add and an elementwise maximum with zero. Checked through
// RunAndCompare with ErrorSpec{0.01, 0.01}.
XLA_TEST_F(ConvolutionHloTest, TestFusedConv2D) {
  constexpr char kHlo[] = R"(
HloModule TestModule

ENTRY TestComputation {
  %p0 = f32[8,5,5,1] parameter(0)
  %p1 = f32[3,3,1,32] parameter(1)
  %conv = f32[8,5,5,32] convolution(p0, p1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
  %bias = f32[32] parameter(2)
  %broadcasted_bias = f32[8,5,5,32] broadcast(%bias), dimensions={3}
  %add = f32[8,5,5,32] add(%conv, %broadcasted_bias)
  %zero = f32[] constant(0)
  %zeros = f32[8,5,5,32] broadcast(%zero), dimensions={}
  ROOT relu = f32[8,5,5,32] maximum(%zeros, %add)
})";
  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
1757
// Runs a 3D f32 convolution (3x3x3 window, padding 1 on every side) followed
// by a broadcast bias add and an elementwise maximum with zero. Checked
// through RunAndCompare with ErrorSpec{0.01, 0.01}.
XLA_TEST_F(ConvolutionHloTest, TestFusedConv3D) {
  constexpr char kHlo[] = R"(
HloModule TestModule

ENTRY TestComputation {
  %p0 = f32[8,4,5,5,1] parameter(0)
  %p1 = f32[3,3,3,1,32] parameter(1)
  %conv = f32[8,4,5,5,32] convolution(p0, p1), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f
  %bias = f32[32] parameter(2)
  %broadcasted_bias = f32[8,4,5,5,32] broadcast(%bias), dimensions={4}
  %add = f32[8,4,5,5,32] add(%conv, %broadcasted_bias)
  %zero = f32[] constant(0)
  %zeros = f32[8,4,5,5,32] broadcast(%zero), dimensions={}
  ROOT relu = f32[8,4,5,5,32] maximum(%zeros, %add)
})";
  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
1775
// Runs a 1D convolution on pred (boolean) operands: an all-true pred[3,3,3]
// broadcast is convolved with itself (window size 3, padding 1 on both sides)
// and the result is returned wrapped in a one-element tuple. Checked through
// RunAndCompare with ErrorSpec{0.01, 0.01}.
XLA_TEST_F(ConvolutionHloTest, TestBooleanInput) {
  constexpr char kHlo[] = R"(
HloModule TestModule

ENTRY TestComputation {
  constant.1 = pred[] constant(true)
  broadcast.2 = pred[3,3,3]{2,1,0} broadcast(constant.1), dimensions={}
  convolution.3 = pred[3,3,3]{2,1,0} convolution(broadcast.2, broadcast.2), window={size=3 pad=1_1}, dim_labels=bf0_oi0->bf0
  ROOT tuple.4 = (pred[3,3,3]{2,1,0}) tuple(convolution.3)
})";
  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
1788
1789 } // namespace
1790 } // namespace xla
1791