1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Tests of convolution with trivial kernels and no special variations (like
17 // strides and padding).
18
#include <memory>
#include <numeric>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
41
42 namespace xla {
43 namespace {
44
45 class ConvolutionTest : public ClientLibraryTestBase {
46 protected:
47 #if XLA_TEST_BACKEND_GPU
48 // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial
49 // convolution. So relax the absolute error threshold.
50 ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-4);
51 #else
52 ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-4);
53 #endif
54 };
55
56 #ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
57 using TestTypes = ::testing::Types<float>;
58 #else
59 using TestTypes = ::testing::Types<float, Eigen::half>;
60 #endif
61
62 template <typename T>
63 class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
64 public:
RunTest()65 void RunTest() {
66 const int kInputActivationSizeY = 3;
67 const int kInputActivationSizeX = 3;
68 const int kInputActivationSizeZ = 256;
69 const int kKernelSizeX = 2;
70 const int kKernelSizeY = 2;
71 const int kOutputActivationSizeZ = 256;
72 const int kMiniBatchSize = 4;
73 auto alhs = absl::make_unique<Array4D<T>>(
74 kMiniBatchSize, kInputActivationSizeZ, kInputActivationSizeY,
75 kInputActivationSizeX);
76 alhs->FillWithMultiples(static_cast<T>(1.0f));
77 ASSERT_EQ(3, alhs->width());
78 ASSERT_EQ(3, alhs->height());
79
80 auto arhs = absl::make_unique<Array4D<T>>(kOutputActivationSizeZ,
81 kInputActivationSizeZ,
82 kKernelSizeY, kKernelSizeX);
83 Array2D<T> rhs_raster({
84 {1.0f, 0.0f}, // row 0
85 {0.0f, 0.0f}, // row 1
86 });
87 arhs->FillWithYX(rhs_raster);
88 ASSERT_EQ(2, arhs->width());
89 ASSERT_EQ(2, arhs->height());
90
91 XlaBuilder builder(TestName());
92 auto lhs = ConstantR4FromArray4D<T>(&builder, *alhs);
93 auto rhs = ConstantR4FromArray4D<T>(&builder, *arhs);
94 PrecisionConfig precision;
95 // The left hand side of the convolution is numbers between 0 and 2304 which
96 // requires at least 11 mantissa bits and the DEFAULT precision config is
97 // allowed to round to bfloat16 which only has 7 mantissa bits.
98 precision.add_operand_precision(PrecisionConfig::HIGHEST);
99 precision.add_operand_precision(PrecisionConfig::DEFAULT);
100 Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1,
101 /*batch_group_count=*/1, &precision);
102
103 ComputeAndCompare(&builder, {}, error_spec_);
104 }
105 };
106
107 TYPED_TEST_CASE(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, TestTypes);
XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota,Types)108 XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, Types) {
109 this->RunTest();
110 }
111
112 template <typename T>
113 class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
114 public:
RunTest()115 void RunTest() {
116 XlaBuilder builder(TestName());
117 Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
118 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
119 auto input = Parameter(&builder, 0, input_shape, "input");
120 auto filter = Parameter(&builder, 1, filter_shape, "filter");
121 Conv(input, filter, {1, 1}, Padding::kValid);
122
123 Array4D<T> input_data(1, 1, 1, 2);
124 input_data.FillWithYX(Array2D<T>({
125 {1.0f, 2.0f},
126 }));
127 Array4D<T> filter_data(1, 1, 1, 2);
128 filter_data.FillWithYX(Array2D<T>({
129 {5.0f, 6.0f},
130 }));
131
132 ComputeAndCompare(&builder,
133 {LiteralUtil::CreateFromArray(input_data),
134 LiteralUtil::CreateFromArray(filter_data)},
135 error_spec_);
136 }
137 };
138
139 TYPED_TEST_CASE(Convolve_1x1x1x2_1x1x1x2_Valid, TestTypes);
TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid,Types)140 TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid, Types) { this->RunTest(); }
141
142 // Tests valid padding for 2D convolution in raster space.
143 template <typename T>
144 class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
145 public:
RunTest()146 void RunTest() {
147 XlaBuilder builder(TestName());
148 Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
149 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
150 auto input = Parameter(&builder, 0, input_shape, "input");
151 auto filter = Parameter(&builder, 1, filter_shape, "filter");
152 Conv(input, filter, {1, 1}, Padding::kValid);
153
154 Array4D<T> input_data(1, 1, 4, 4);
155 input_data.FillWithYX(Array2D<T>({
156 {1.0f, 2.0f, 3.0f, 4.0f},
157 {5.0f, 6.0f, 7.0f, 8.0f},
158 {9.0f, 10.0f, 11.0f, 12.0f},
159 {13.0f, 14.0f, 15.0f, 16.0f},
160 }));
161 Array4D<T> filter_data(1, 1, 2, 2);
162 filter_data.FillWithYX(Array2D<T>({
163 {5.0f, 6.0f},
164 {7.0f, 8.0f},
165 }));
166 ComputeAndCompare(&builder,
167 {LiteralUtil::CreateFromArray(input_data),
168 LiteralUtil::CreateFromArray(filter_data)},
169 error_spec_);
170 }
171 };
172
173 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Valid, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid,Types)174 TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid, Types) { this->RunTest(); }
175
176 // Tests same padding for 2D convolution in raster space.
177 template <typename T>
178 class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
179 public:
RunTest()180 void RunTest() {
181 XlaBuilder builder(TestName());
182 Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
183 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
184 auto input = Parameter(&builder, 0, input_shape, "input");
185 auto filter = Parameter(&builder, 1, filter_shape, "filter");
186 Conv(input, filter, {1, 1}, Padding::kSame);
187
188 Array4D<T> input_data(1, 1, 4, 4);
189 input_data.FillWithYX(Array2D<T>({
190 {1.0f, 2.0f, 3.0f, 4.0f},
191 {5.0f, 6.0f, 7.0f, 8.0f},
192 {9.0f, 10.0f, 11.0f, 12.0f},
193 {13.0f, 14.0f, 15.0f, 16.0f},
194 }));
195 Array4D<T> filter_data(1, 1, 2, 2);
196 filter_data.FillWithYX(Array2D<T>({
197 {5.0f, 6.0f},
198 {7.0f, 8.0f},
199 }));
200
201 ComputeAndCompare(&builder,
202 {LiteralUtil::CreateFromArray(input_data),
203 LiteralUtil::CreateFromArray(filter_data)},
204 error_spec_);
205 }
206 };
207
208 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Same, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same,Types)209 TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same, Types) { this->RunTest(); }
210
211 // Tests same padding for 2D convolution in raster space with an odd sized
212 // kernel.
213 template <typename T>
214 class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
215 public:
RunTest()216 void RunTest() {
217 XlaBuilder builder(TestName());
218 Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
219 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 3, 3});
220 auto input = Parameter(&builder, 0, input_shape, "input");
221 auto filter = Parameter(&builder, 1, filter_shape, "filter");
222 Conv(input, filter, {1, 1}, Padding::kSame);
223
224 Array4D<T> input_data(1, 1, 4, 4);
225 input_data.FillWithYX(Array2D<T>({{1.0f, 2.0f, 3.0f, 4.0f},
226 {5.0f, 6.0f, 7.0f, 8.0f},
227 {9.0f, 10.0f, 11.0f, 12.0f},
228 {13.0f, 14.0f, 15.0f, 16.0f}}));
229 Array4D<T> filter_data(1, 1, 3, 3);
230 filter_data.FillWithYX(Array2D<T>(
231 {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
232 // clang-format on
233 ComputeAndCompare(&builder,
234 {LiteralUtil::CreateFromArray(input_data),
235 LiteralUtil::CreateFromArray(filter_data)},
236 error_spec_);
237 }
238 };
239
240 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same,Types)241 TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); }
242
// 1D valid-padding convolution of a 1x2x5 F32 input with a 1x2x2 filter,
// checked against hand-written expected values.
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
  XlaBuilder builder(TestName());
  {
    const Shape lhs_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
    const Shape rhs_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
    auto lhs = Parameter(&builder, 0, lhs_shape, "input");
    auto rhs = Parameter(&builder, 1, rhs_shape, "filter");
    Conv(lhs, rhs, /*window_strides=*/{1}, Padding::kValid);
  }

  const Array3D<float> input_values({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  const Array3D<float> filter_values({{{10, 20}, {30, 40}}});

  // For example, the first output is 1*10 + 2*20 + 6*30 + 7*40 = 510.
  const Array3D<float> expected({{{510, 610, 710, 810}}});

  // Upload both arguments; the handles must stay alive until the comparison
  // below has run.
  std::unique_ptr<GlobalData> input_handle =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input_values))
          .ConsumeValueOrDie();
  std::unique_ptr<GlobalData> filter_handle =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter_values))
          .ConsumeValueOrDie();

  ComputeAndCompareR3<float>(&builder, expected,
                             {input_handle.get(), filter_handle.get()},
                             error_spec_);
}
269
270 template <typename T>
271 class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
272 public:
RunTest()273 void RunTest() {
274 XlaBuilder builder(TestName());
275 {
276 Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
277 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
278 auto input = Parameter(&builder, 0, input_shape, "input");
279 auto filter = Parameter(&builder, 1, filter_shape, "filter");
280 // Convolution dimensions are bf0_oi0->bo0.
281 ConvGeneralDilated(
282 input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
283 /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
284 /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
285 }
286
287 Array3D<T> input(
288 {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
289 Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
290
291 Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
292
293 auto input_literal =
294 client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
295 .ConsumeValueOrDie();
296 auto filter_literal =
297 client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
298 .ConsumeValueOrDie();
299
300 ComputeAndCompareR3<T>(&builder, expected,
301 {input_literal.get(), filter_literal.get()},
302 error_spec_);
303 }
304 }; // namespace
305
306 TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation,Types)307 TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
308
// 1D convolution of a 1x2x5 F32 input dilated by 2 with a 1x2x2 filter,
// checked against hand-written expected values.
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
  XlaBuilder builder(TestName());
  {
    const Shape lhs_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
    const Shape rhs_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
    auto lhs = Parameter(&builder, 0, lhs_shape, "input");
    auto rhs = Parameter(&builder, 1, rhs_shape, "filter");
    // Convolution dimensions are bf0_oi0->bo0.
    ConvGeneralDilated(
        lhs, rhs, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
        /*lhs_dilation=*/{2}, /*rhs_dilation=*/{1},
        /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
  }

  const Array3D<float> input_values({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  const Array3D<float> filter_values({{{10, 20}, {30, 40}}});

  // For example, the first two outputs are 1*10 + 6*30 = 190 and
  // 2*20 + 7*40 = 320.
  const Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});

  // Upload both arguments; the handles must stay alive until the comparison
  // below has run.
  std::unique_ptr<GlobalData> input_handle =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input_values))
          .ConsumeValueOrDie();
  std::unique_ptr<GlobalData> filter_handle =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter_values))
          .ConsumeValueOrDie();

  ComputeAndCompareR3<float>(&builder, expected,
                             {input_handle.get(), filter_handle.get()},
                             error_spec_);
}
339
// 1D convolution of a 1x2x5 F32 input with a 1x2x2 filter where both the
// input and the window are dilated by 2, checked against hand-written
// expected values.
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
  XlaBuilder builder(TestName());
  {
    const Shape lhs_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
    const Shape rhs_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
    auto lhs = Parameter(&builder, 0, lhs_shape, "input");
    auto rhs = Parameter(&builder, 1, rhs_shape, "filter");
    // Convolution dimensions are bf0_oi0->bo0.
    ConvGeneralDilated(
        lhs, rhs, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
        /*lhs_dilation=*/{2}, /*rhs_dilation=*/{2},
        /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
  }

  const Array3D<float> input_values({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  const Array3D<float> filter_values({{{10, 20}, {30, 40}}});

  // Every other output is zero; the non-zero entries equal the undilated
  // valid-convolution values (e.g. 1*10 + 2*20 + 6*30 + 7*40 = 510).
  const Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});

  // Upload both arguments; the handles must stay alive until the comparison
  // below has run.
  std::unique_ptr<GlobalData> input_handle =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input_values))
          .ConsumeValueOrDie();
  std::unique_ptr<GlobalData> filter_handle =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter_values))
          .ConsumeValueOrDie();

  ComputeAndCompareR3<float>(&builder, expected,
                             {input_handle.get(), filter_handle.get()},
                             error_spec_);
}
370
371 template <typename T>
372 class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
373 public:
RunTest()374 void RunTest() {
375 XlaBuilder builder(TestName());
376 {
377 Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
378 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
379 auto input = Parameter(&builder, 0, input_shape, "input");
380 auto filter = Parameter(&builder, 1, filter_shape, "filter");
381 // Convolution dimensions are bf0_oi0->bo0.
382 ConvGeneralDilated(
383 input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
384 /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
385 /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
386 }
387
388 Array3D<T> input(
389 {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
390 Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
391
392 Array3D<T> expected(
393 {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
394
395 auto input_literal =
396 client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
397 .ConsumeValueOrDie();
398 auto filter_literal =
399 client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
400 .ConsumeValueOrDie();
401
402 ComputeAndCompareR3<T>(&builder, expected,
403 {input_literal.get(), filter_literal.get()},
404 error_spec_);
405 }
406 };
407
408 TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding,Types)409 TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
410
// 3D valid-padding convolution of a 1x4x2x3x3 F32 input with a 2x2x2x3x3
// filter, using TensorFlow-style dimension numbers. Both operands are filled
// with 1, 2, 3, ... and the result is checked against hand-written expected
// values.
XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
  XlaBuilder builder(TestName());
  const std::vector<int64> input_dims = {1, 4, 2, 3, 3};
  const std::vector<int64> filter_dims = {2, 2, 2, 3, 3};
  const Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
  const Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
  {
    auto input = Parameter(&builder, 0, input_shape, "input");
    auto filter = Parameter(&builder, 1, filter_shape, "filter");

    // Tensorflow dimension numbers for 3D convolution.
    ConvolutionDimensionNumbers dnums;
    // Input and output: batch, three spatial dimensions, feature.
    dnums.set_input_batch_dimension(0);
    dnums.set_output_batch_dimension(0);
    for (const int64 spatial_dim : {1, 2, 3}) {
      dnums.add_input_spatial_dimensions(spatial_dim);
      dnums.add_output_spatial_dimensions(spatial_dim);
    }
    dnums.set_input_feature_dimension(4);
    dnums.set_output_feature_dimension(4);
    // Kernel: three spatial dimensions, input feature, output feature.
    for (const int64 spatial_dim : {0, 1, 2}) {
      dnums.add_kernel_spatial_dimensions(spatial_dim);
    }
    dnums.set_kernel_input_feature_dimension(3);
    dnums.set_kernel_output_feature_dimension(4);

    ConvWithGeneralDimensions(input, filter, /*window_strides=*/{1, 1, 1},
                              Padding::kValid, dnums);
  }

  // std::iota is qualified explicitly: an unqualified call only resolves via
  // argument-dependent lookup on the iterator type, which is not portable.
  std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
  std::iota(input_elems.begin(), input_elems.end(), 1.0f);
  auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
  auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie();

  std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
  std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
  auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
  auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();

  auto expected_r1 = LiteralUtil::CreateR1<float>(
      {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
       38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
  auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();

  // Upload both arguments; the handles must stay alive until the comparison
  // below has run.
  auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie();
  auto filter_literal =
      client_->TransferToServer(filter_r5).ConsumeValueOrDie();

  ComputeAndCompareLiteral(&builder, expected_r5,
                           {input_literal.get(), filter_literal.get()},
                           error_spec_);
}
465
// Fills `values` with consecutive integers starting at `init_value`, each
// converted to T with static_cast. An empty vector is left untouched.
//
// std::iota is not used here because, on some build servers, it fails to
// compile when the initial value has type Eigen::half (the error reports a
// missing operator++).
template <typename T>
void iota_int_init_value(std::vector<T>& values, int init_value) {
  // A plain range-for keeps this helper free of any non-standard header.
  for (T& value : values) {
    value = static_cast<T>(init_value++);
  }
}
473
474 template <typename T>
475 class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
476 public:
RunTest()477 void RunTest() {
478 XlaBuilder builder(TestName());
479 std::vector<int64> input_dims = {1, 3, 3, 5};
480 std::vector<int64> filter_dims = {3, 3, 5, 3};
481 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
482 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
483 {
484 auto input = Parameter(&builder, 0, input_shape, "input");
485 auto filter = Parameter(&builder, 1, filter_shape, "filter");
486
487 // Tensorflow dimension numbers for 2D convolution.
488 ConvolutionDimensionNumbers dnums;
489 dnums.set_input_batch_dimension(0);
490 dnums.set_output_batch_dimension(0);
491 dnums.add_input_spatial_dimensions(1);
492 dnums.add_output_spatial_dimensions(1);
493 dnums.add_input_spatial_dimensions(2);
494 dnums.add_output_spatial_dimensions(2);
495 dnums.set_input_feature_dimension(3);
496 dnums.set_output_feature_dimension(3);
497 dnums.add_kernel_spatial_dimensions(0);
498 dnums.add_kernel_spatial_dimensions(1);
499 dnums.set_kernel_input_feature_dimension(2);
500 dnums.set_kernel_output_feature_dimension(3);
501
502 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums);
503 }
504
505 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
506 iota_int_init_value(input_elems, 1);
507 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
508 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
509
510 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
511 iota_int_init_value(filter_elems, 1);
512 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
513 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
514
515 auto expected_r1 = LiteralUtil::CreateR1<T>(
516 {static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
517 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
518
519 auto input_literal =
520 client_->TransferToServer(input_r4).ConsumeValueOrDie();
521 auto filter_literal =
522 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
523
524 ComputeAndCompareLiteral(&builder, expected_r4,
525 {input_literal.get(), filter_literal.get()},
526 error_spec_);
527 }
528 };
529
530 TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x3_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid,Types)531 TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid, Types) { this->RunTest(); }
532
533 template <typename T>
534 class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
535 public:
RunTest()536 void RunTest() {
537 XlaBuilder builder(TestName());
538 std::vector<int64> input_dims = {1, 3, 3, 5};
539 std::vector<int64> filter_dims = {3, 3, 1, 15};
540 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
541 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
542 {
543 auto input = Parameter(&builder, 0, input_shape, "input");
544 auto filter = Parameter(&builder, 1, filter_shape, "filter");
545
546 // Tensorflow dimension numbers for 2D convolution.
547 ConvolutionDimensionNumbers dnums;
548 dnums.set_input_batch_dimension(0);
549 dnums.set_output_batch_dimension(0);
550 dnums.add_input_spatial_dimensions(1);
551 dnums.add_output_spatial_dimensions(1);
552 dnums.add_input_spatial_dimensions(2);
553 dnums.add_output_spatial_dimensions(2);
554 dnums.set_input_feature_dimension(3);
555 dnums.set_output_feature_dimension(3);
556 dnums.add_kernel_spatial_dimensions(0);
557 dnums.add_kernel_spatial_dimensions(1);
558 dnums.set_kernel_input_feature_dimension(2);
559 dnums.set_kernel_output_feature_dimension(3);
560
561 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
562 /*feature_group_count=*/5);
563 }
564
565 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
566 iota_int_init_value(input_elems, 1);
567 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
568 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
569
570 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
571 iota_int_init_value(filter_elems, 1);
572 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
573 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
574
575 auto expected_r1 = LiteralUtil::CreateR1<T>(
576 {static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
577 static_cast<T>(17172), static_cast<T>(17370), static_cast<T>(17568),
578 static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
579 static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
580 static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
581 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
582
583 auto input_literal =
584 client_->TransferToServer(input_r4).ConsumeValueOrDie();
585 auto filter_literal =
586 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
587
588 ComputeAndCompareLiteral(&builder, expected_r4,
589 {input_literal.get(), filter_literal.get()},
590 error_spec_);
591 }
592 };
593
594 TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid,Types)595 TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) {
596 this->RunTest();
597 }
598
599 template <typename T>
600 class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid : public ConvolutionTest {
601 public:
RunTest()602 void RunTest() {
603 XlaBuilder builder(TestName());
604 std::vector<int64> input_dims = {1, 4, 4, 5};
605 std::vector<int64> filter_dims = {3, 3, 1, 5};
606 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
607 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
608 {
609 auto input = Parameter(&builder, 0, input_shape, "input");
610 auto filter = Parameter(&builder, 1, filter_shape, "filter");
611
612 // Tensorflow dimension numbers for 2D convolution.
613 ConvolutionDimensionNumbers dnums;
614 dnums.set_input_batch_dimension(0);
615 dnums.set_output_batch_dimension(0);
616 dnums.add_input_spatial_dimensions(1);
617 dnums.add_output_spatial_dimensions(1);
618 dnums.add_input_spatial_dimensions(2);
619 dnums.add_output_spatial_dimensions(2);
620 dnums.set_input_feature_dimension(3);
621 dnums.set_output_feature_dimension(3);
622 dnums.add_kernel_spatial_dimensions(0);
623 dnums.add_kernel_spatial_dimensions(1);
624 dnums.set_kernel_input_feature_dimension(2);
625 dnums.set_kernel_output_feature_dimension(3);
626
627 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
628 /*feature_group_count=*/5);
629 }
630
631 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
632 iota_int_init_value(input_elems, 1);
633 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
634 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
635
636 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
637 iota_int_init_value(filter_elems, 1);
638 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
639 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
640
641 auto expected_r1 = LiteralUtil::CreateR1<T>(
642 {static_cast<T>(6864), static_cast<T>(7296), static_cast<T>(7746),
643 static_cast<T>(8214), static_cast<T>(8700), static_cast<T>(7809),
644 static_cast<T>(8286), static_cast<T>(8781), static_cast<T>(9294),
645 static_cast<T>(9825), static_cast<T>(10644), static_cast<T>(11256),
646 static_cast<T>(11886), static_cast<T>(12534), static_cast<T>(13200),
647 static_cast<T>(11589), static_cast<T>(12246), static_cast<T>(12921),
648 static_cast<T>(13614), static_cast<T>(14325)});
649 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie();
650
651 auto input_literal =
652 client_->TransferToServer(input_r4).ConsumeValueOrDie();
653 auto filter_literal =
654 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
655
656 ComputeAndCompareLiteral(&builder, expected_r4,
657 {input_literal.get(), filter_literal.get()},
658 error_spec_);
659
660 auto filter_r = filter_r1.Reshape(filter_dims);
661 }
662 };
663
664 TYPED_TEST_CASE(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid,Types)665 TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, Types) {
666 this->RunTest();
667 }
668
669 template <typename T>
670 class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid : public ConvolutionTest {
671 public:
RunTest()672 void RunTest() {
673 XlaBuilder builder(TestName());
674 std::vector<int64> input_dims = {1, 4, 4, 512};
675 std::vector<int64> filter_dims = {3, 3, 1, 512};
676 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
677 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
678 {
679 auto input = Parameter(&builder, 0, input_shape, "input");
680 auto filter = Parameter(&builder, 1, filter_shape, "filter");
681
682 // Tensorflow dimension numbers for 2D convolution.
683 ConvolutionDimensionNumbers dnums;
684 dnums.set_input_batch_dimension(0);
685 dnums.set_output_batch_dimension(0);
686 dnums.add_input_spatial_dimensions(1);
687 dnums.add_output_spatial_dimensions(1);
688 dnums.add_input_spatial_dimensions(2);
689 dnums.add_output_spatial_dimensions(2);
690 dnums.set_input_feature_dimension(3);
691 dnums.set_output_feature_dimension(3);
692 dnums.add_kernel_spatial_dimensions(0);
693 dnums.add_kernel_spatial_dimensions(1);
694 dnums.set_kernel_input_feature_dimension(2);
695 dnums.set_kernel_output_feature_dimension(3);
696
697 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
698 /*feature_group_count=*/512);
699 }
700
701 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
702 static_cast<T>(1));
703 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
704 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
705
706 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
707 static_cast<T>(2));
708 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
709 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
710
711 std::vector<T> output_elems(2048, static_cast<T>(18));
712
713 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
714 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie();
715
716 auto input_literal =
717 client_->TransferToServer(input_r4).ConsumeValueOrDie();
718 auto filter_literal =
719 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
720
721 ComputeAndCompareLiteral(&builder, expected_r4,
722 {input_literal.get(), filter_literal.get()},
723 error_spec_);
724 }
725 };
726
727 TYPED_TEST_CASE(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid,Types)728 TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, Types) {
729 this->RunTest();
730 }
731
732 template <typename T>
733 class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes
734 : public ConvolutionTest {
735 public:
RunTest()736 void RunTest() {
737 XlaBuilder builder(TestName());
738 std::vector<int64> input_dims = {1, 4, 4, 512};
739 std::vector<int64> filter_dims = {3, 3, 1, 512};
740 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
741 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
742 {
743 auto input = Parameter(&builder, 0, input_shape, "input");
744 auto filter = Parameter(&builder, 1, filter_shape, "filter");
745
746 // Tensorflow dimension numbers for 2D convolution.
747 ConvolutionDimensionNumbers dnums;
748 dnums.set_input_batch_dimension(0);
749 dnums.set_output_batch_dimension(0);
750 dnums.add_input_spatial_dimensions(1);
751 dnums.add_output_spatial_dimensions(1);
752 dnums.add_input_spatial_dimensions(2);
753 dnums.add_output_spatial_dimensions(2);
754 dnums.set_input_feature_dimension(3);
755 dnums.set_output_feature_dimension(3);
756 dnums.add_kernel_spatial_dimensions(0);
757 dnums.add_kernel_spatial_dimensions(1);
758 dnums.set_kernel_input_feature_dimension(2);
759 dnums.set_kernel_output_feature_dimension(3);
760
761 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
762 /*feature_group_count=*/512);
763 }
764
765 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
766 static_cast<T>(1));
767 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
768 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
769
770 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
771 static_cast<T>(2));
772 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
773 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
774
775 std::vector<T> output_elems(2048, static_cast<T>(18));
776
777 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
778 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie();
779 auto expected_r4_relaid =
780 expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
781
782 auto input_literal =
783 client_->TransferToServer(input_r4).ConsumeValueOrDie();
784 auto filter_literal =
785 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
786
787 ComputeAndCompareLiteral(&builder, expected_r4_relaid,
788 {input_literal.get(), filter_literal.get()},
789 error_spec_, &expected_r4_relaid.shape());
790 }
791 };
792
// Registers the fixture as a typed test case over TestTypes. The single test
// body only forwards to the fixture's RunTest().
TYPED_TEST_CASE(
    Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes,
    TestTypes);
TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes,
           Types) {
  this->RunTest();
}
800
801 template <typename T>
802 class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes
803 : public ConvolutionTest {
804 public:
RunTest()805 void RunTest() {
806 XlaBuilder builder(TestName());
807 std::vector<int64> input_dims = {256, 4, 4, 512};
808 std::vector<int64> filter_dims = {3, 3, 1, 512};
809 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
810 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
811 {
812 auto input = Parameter(&builder, 0, input_shape, "input");
813 auto filter = Parameter(&builder, 1, filter_shape, "filter");
814
815 // Tensorflow dimension numbers for 2D convolution.
816 ConvolutionDimensionNumbers dnums;
817 dnums.set_input_batch_dimension(0);
818 dnums.set_output_batch_dimension(0);
819 dnums.add_input_spatial_dimensions(1);
820 dnums.add_output_spatial_dimensions(1);
821 dnums.add_input_spatial_dimensions(2);
822 dnums.add_output_spatial_dimensions(2);
823 dnums.set_input_feature_dimension(3);
824 dnums.set_output_feature_dimension(3);
825 dnums.add_kernel_spatial_dimensions(0);
826 dnums.add_kernel_spatial_dimensions(1);
827 dnums.set_kernel_input_feature_dimension(2);
828 dnums.set_kernel_output_feature_dimension(3);
829
830 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
831 /*feature_group_count=*/512);
832 }
833
834 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
835 static_cast<T>(1));
836 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
837 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
838 auto input_r4_relaid =
839 input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
840
841 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
842 static_cast<T>(2));
843 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
844 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
845
846 std::vector<T> output_elems(2048 * 256, static_cast<T>(18));
847
848 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
849 auto expected_r4 =
850 expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie();
851
852 auto input_literal =
853 client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie();
854 auto filter_literal =
855 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
856
857 ComputeAndCompareLiteral(&builder, expected_r4,
858 {input_literal.get(), filter_literal.get()},
859 error_spec_);
860 }
861 };
862
863 TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes,
864 TestTypes);
TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes,Types)865 TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes,
866 Types) {
867 this->RunTest();
868 }
869
870 template <typename T>
871 class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes
872 : public ConvolutionTest {
873 public:
RunTest()874 void RunTest() {
875 XlaBuilder builder(TestName());
876 std::vector<int64> input_dims = {256, 4, 4, 512};
877 std::vector<int64> filter_dims = {3, 3, 1, 512};
878 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
879 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
880 {
881 auto input = Parameter(&builder, 0, input_shape, "input");
882 auto filter = Parameter(&builder, 1, filter_shape, "filter");
883
884 // Tensorflow dimension numbers for 2D convolution.
885 ConvolutionDimensionNumbers dnums;
886 dnums.set_input_batch_dimension(0);
887 dnums.set_output_batch_dimension(0);
888 dnums.add_input_spatial_dimensions(1);
889 dnums.add_output_spatial_dimensions(1);
890 dnums.add_input_spatial_dimensions(2);
891 dnums.add_output_spatial_dimensions(2);
892 dnums.set_input_feature_dimension(3);
893 dnums.set_output_feature_dimension(3);
894 dnums.add_kernel_spatial_dimensions(0);
895 dnums.add_kernel_spatial_dimensions(1);
896 dnums.set_kernel_input_feature_dimension(2);
897 dnums.set_kernel_output_feature_dimension(3);
898
899 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
900 /*feature_group_count=*/512);
901 }
902
903 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
904 static_cast<T>(1));
905 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
906 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
907 auto input_r4_relaid =
908 input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
909
910 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
911 static_cast<T>(2));
912 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
913 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
914
915 std::vector<T> output_elems(2048 * 256, static_cast<T>(18));
916
917 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
918 auto expected_r4 =
919 expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie();
920 auto expected_r4_relaid =
921 expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
922
923 auto input_literal =
924 client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie();
925 auto filter_literal =
926 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
927
928 ComputeAndCompareLiteral(&builder, expected_r4_relaid,
929 {input_literal.get(), filter_literal.get()},
930 error_spec_, &expected_r4_relaid.shape());
931 }
932 };
933
934 TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes,
935 TestTypes);
TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes,Types)936 TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes,
937 Types) {
938 this->RunTest();
939 }
940
941 template <typename T>
942 class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes
943 : public ConvolutionTest {
944 public:
RunTest()945 void RunTest() {
946 XlaBuilder builder(TestName());
947 std::vector<int64> input_dims = {1, 4, 4, 5};
948 std::vector<int64> filter_dims = {3, 3, 1, 5};
949 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
950 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
951 {
952 auto input = Parameter(&builder, 0, input_shape, "input");
953 auto filter = Parameter(&builder, 1, filter_shape, "filter");
954
955 // Tensorflow dimension numbers for 2D convolution.
956 ConvolutionDimensionNumbers dnums;
957 dnums.set_input_batch_dimension(0);
958 dnums.set_output_batch_dimension(0);
959 dnums.add_input_spatial_dimensions(1);
960 dnums.add_output_spatial_dimensions(1);
961 dnums.add_input_spatial_dimensions(2);
962 dnums.add_output_spatial_dimensions(2);
963 dnums.set_input_feature_dimension(3);
964 dnums.set_output_feature_dimension(3);
965 dnums.add_kernel_spatial_dimensions(0);
966 dnums.add_kernel_spatial_dimensions(1);
967 dnums.set_kernel_input_feature_dimension(2);
968 dnums.set_kernel_output_feature_dimension(3);
969
970 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
971 /*feature_group_count=*/5);
972 }
973
974 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
975 iota_int_init_value(input_elems, 1);
976 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
977 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
978 auto input_r4_relaid =
979 input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
980
981 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
982 iota_int_init_value(filter_elems, 1);
983 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
984 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
985
986 auto expected_r1 = LiteralUtil::CreateR1<T>(
987 {static_cast<T>(6864), static_cast<T>(7296), static_cast<T>(7746),
988 static_cast<T>(8214), static_cast<T>(8700), static_cast<T>(7809),
989 static_cast<T>(8286), static_cast<T>(8781), static_cast<T>(9294),
990 static_cast<T>(9825), static_cast<T>(10644), static_cast<T>(11256),
991 static_cast<T>(11886), static_cast<T>(12534), static_cast<T>(13200),
992 static_cast<T>(11589), static_cast<T>(12246), static_cast<T>(12921),
993 static_cast<T>(13614), static_cast<T>(14325)});
994 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie();
995 auto expected_r4_relaid =
996 expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
997
998 auto input_literal =
999 client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie();
1000 auto filter_literal =
1001 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1002
1003 ComputeAndCompareLiteral(&builder, expected_r4_relaid,
1004 {input_literal.get(), filter_literal.get()},
1005 error_spec_, &expected_r4_relaid.shape());
1006 }
1007 };
1008
1009 TYPED_TEST_CASE(
1010 Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes,
1011 TestTypes);
TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes,Types)1012 TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes,
1013 Types) {
1014 this->RunTest();
1015 }
1016
1017 template <typename T>
1018 class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest {
1019 public:
RunTest()1020 void RunTest() {
1021 XlaBuilder builder(TestName());
1022 std::vector<int64> input_dims = {1, 4, 4, 160};
1023 std::vector<int64> filter_dims = {3, 3, 1, 160};
1024 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1025 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1026 {
1027 auto input = Parameter(&builder, 0, input_shape, "input");
1028 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1029
1030 // Tensorflow dimension numbers for 2D convolution.
1031 ConvolutionDimensionNumbers dnums;
1032 dnums.set_input_batch_dimension(0);
1033 dnums.set_output_batch_dimension(0);
1034 dnums.add_input_spatial_dimensions(1);
1035 dnums.add_output_spatial_dimensions(1);
1036 dnums.add_input_spatial_dimensions(2);
1037 dnums.add_output_spatial_dimensions(2);
1038 dnums.set_input_feature_dimension(3);
1039 dnums.set_output_feature_dimension(3);
1040 dnums.add_kernel_spatial_dimensions(0);
1041 dnums.add_kernel_spatial_dimensions(1);
1042 dnums.set_kernel_input_feature_dimension(2);
1043 dnums.set_kernel_output_feature_dimension(3);
1044
1045 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1046 /*feature_group_count=*/160);
1047 }
1048
1049 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1050 static_cast<T>(1));
1051 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1052 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1053
1054 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1055 static_cast<T>(2));
1056 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1057 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1058
1059 std::vector<T> output_elems(640, static_cast<T>(18));
1060
1061 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1062 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie();
1063
1064 auto input_literal =
1065 client_->TransferToServer(input_r4).ConsumeValueOrDie();
1066 auto filter_literal =
1067 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1068
1069 ComputeAndCompareLiteral(&builder, expected_r4,
1070 {input_literal.get(), filter_literal.get()},
1071 error_spec_);
1072 }
1073 };
1074
1075 TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid,Types)1076 TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, Types) {
1077 this->RunTest();
1078 }
1079
1080 template <typename T>
1081 class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes
1082 : public ConvolutionTest {
1083 public:
RunTest()1084 void RunTest() {
1085 XlaBuilder builder(TestName());
1086 std::vector<int64> input_dims = {1, 4, 4, 160};
1087 std::vector<int64> filter_dims = {3, 3, 1, 160};
1088 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1089 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1090 {
1091 auto input = Parameter(&builder, 0, input_shape, "input");
1092 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1093
1094 // Tensorflow dimension numbers for 2D convolution.
1095 ConvolutionDimensionNumbers dnums;
1096 dnums.set_input_batch_dimension(0);
1097 dnums.set_output_batch_dimension(0);
1098 dnums.add_input_spatial_dimensions(1);
1099 dnums.add_output_spatial_dimensions(1);
1100 dnums.add_input_spatial_dimensions(2);
1101 dnums.add_output_spatial_dimensions(2);
1102 dnums.set_input_feature_dimension(3);
1103 dnums.set_output_feature_dimension(3);
1104 dnums.add_kernel_spatial_dimensions(0);
1105 dnums.add_kernel_spatial_dimensions(1);
1106 dnums.set_kernel_input_feature_dimension(2);
1107 dnums.set_kernel_output_feature_dimension(3);
1108
1109 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1110 /*feature_group_count=*/160);
1111 }
1112
1113 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1114 static_cast<T>(1));
1115 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1116 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1117 auto input_r4_relaid =
1118 input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
1119
1120 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1121 static_cast<T>(2));
1122 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1123 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1124
1125 std::vector<T> output_elems(640, static_cast<T>(18));
1126
1127 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1128 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie();
1129 auto expected_r4_relaid =
1130 expected_r4.Relayout(LayoutUtil::MakeLayout({3, 0, 2, 1}));
1131
1132 auto input_literal =
1133 client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie();
1134 auto filter_literal =
1135 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1136
1137 ComputeAndCompareLiteral(&builder, expected_r4_relaid,
1138 {input_literal.get(), filter_literal.get()},
1139 error_spec_, &expected_r4_relaid.shape());
1140 }
1141 };
1142
1143 TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes,
1144 TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes,Types)1145 TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes,
1146 Types) {
1147 this->RunTest();
1148 }
1149
1150 template <typename T>
1151 class Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes
1152 : public ConvolutionTest {
1153 public:
RunTest()1154 void RunTest() {
1155 XlaBuilder builder(TestName());
1156 std::vector<int64> input_dims = {1, 4, 4, 160};
1157 std::vector<int64> filter_dims = {3, 3, 1, 160};
1158 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1159 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1160 {
1161 auto input = Parameter(&builder, 0, input_shape, "input");
1162 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1163
1164 // Tensorflow dimension numbers for 2D convolution.
1165 ConvolutionDimensionNumbers dnums;
1166 dnums.set_input_batch_dimension(0);
1167 dnums.set_output_batch_dimension(0);
1168 dnums.add_input_spatial_dimensions(1);
1169 dnums.add_output_spatial_dimensions(1);
1170 dnums.add_input_spatial_dimensions(2);
1171 dnums.add_output_spatial_dimensions(2);
1172 dnums.set_input_feature_dimension(3);
1173 dnums.set_output_feature_dimension(3);
1174 dnums.add_kernel_spatial_dimensions(0);
1175 dnums.add_kernel_spatial_dimensions(1);
1176 dnums.set_kernel_input_feature_dimension(2);
1177 dnums.set_kernel_output_feature_dimension(3);
1178
1179 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1180 /*feature_group_count=*/160);
1181 }
1182
1183 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1184 static_cast<T>(1));
1185 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1186 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1187 auto input_r4_relaid =
1188 input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
1189
1190 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1191 static_cast<T>(2));
1192 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1193 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1194
1195 std::vector<T> output_elems(640, static_cast<T>(18));
1196
1197 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1198 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie();
1199 auto expected_r4_relaid =
1200 expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
1201
1202 auto input_literal =
1203 client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie();
1204 auto filter_literal =
1205 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1206
1207 ComputeAndCompareLiteral(&builder, expected_r4_relaid,
1208 {input_literal.get(), filter_literal.get()},
1209 error_spec_, &expected_r4_relaid.shape());
1210 }
1211 };
1212
1213 TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes,
1214 TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes,Types)1215 TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes,
1216 Types) {
1217 this->RunTest();
1218 }
1219
1220 template <typename T>
1221 class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid
1222 : public ConvolutionTest {
1223 public:
RunTest()1224 void RunTest() {
1225 XlaBuilder builder(TestName());
1226 std::vector<int64> input_dims = {1, 4, 4, 1024};
1227 std::vector<int64> filter_dims = {3, 3, 1, 1024};
1228 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1229 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1230 {
1231 auto input = Parameter(&builder, 0, input_shape, "input");
1232 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1233
1234 // Tensorflow dimension numbers for 2D convolution.
1235 ConvolutionDimensionNumbers dnums;
1236 dnums.set_input_batch_dimension(0);
1237 dnums.set_output_batch_dimension(0);
1238 dnums.add_input_spatial_dimensions(1);
1239 dnums.add_output_spatial_dimensions(1);
1240 dnums.add_input_spatial_dimensions(2);
1241 dnums.add_output_spatial_dimensions(2);
1242 dnums.set_input_feature_dimension(3);
1243 dnums.set_output_feature_dimension(3);
1244 dnums.add_kernel_spatial_dimensions(0);
1245 dnums.add_kernel_spatial_dimensions(1);
1246 dnums.set_kernel_input_feature_dimension(2);
1247 dnums.set_kernel_output_feature_dimension(3);
1248
1249 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1250 /*feature_group_count=*/1024);
1251 }
1252
1253 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1254 static_cast<T>(1));
1255 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1256 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1257
1258 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1259 static_cast<T>(2));
1260 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1261 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1262
1263 std::vector<T> output_elems(4096, static_cast<T>(18));
1264
1265 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1266 auto expected_r4 = expected_r1.Reshape({1, 2, 2, 1024}).ConsumeValueOrDie();
1267
1268 auto input_literal =
1269 client_->TransferToServer(input_r4).ConsumeValueOrDie();
1270 auto filter_literal =
1271 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1272
1273 ComputeAndCompareLiteral(&builder, expected_r4,
1274 {input_literal.get(), filter_literal.get()},
1275 error_spec_);
1276 }
1277 };
1278
1279 TYPED_TEST_CASE(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid,Types)1280 TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, Types) {
1281 this->RunTest();
1282 }
1283
1284 template <typename T>
1285 class Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid : public ConvolutionTest {
1286 public:
RunTest()1287 void RunTest() {
1288 XlaBuilder builder(TestName());
1289 std::vector<int64> input_dims = {1, 2, 2, 6};
1290 std::vector<int64> filter_dims = {2, 2, 2, 12};
1291 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1292 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1293 {
1294 auto input = Parameter(&builder, 0, input_shape, "input");
1295 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1296
1297 // Tensorflow dimension numbers for 2D convolution.
1298 ConvolutionDimensionNumbers dnums;
1299 dnums.set_input_batch_dimension(0);
1300 dnums.set_output_batch_dimension(0);
1301 dnums.add_input_spatial_dimensions(1);
1302 dnums.add_output_spatial_dimensions(1);
1303 dnums.add_input_spatial_dimensions(2);
1304 dnums.add_output_spatial_dimensions(2);
1305 dnums.set_input_feature_dimension(3);
1306 dnums.set_output_feature_dimension(3);
1307 dnums.add_kernel_spatial_dimensions(0);
1308 dnums.add_kernel_spatial_dimensions(1);
1309 dnums.set_kernel_input_feature_dimension(2);
1310 dnums.set_kernel_output_feature_dimension(3);
1311
1312 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1313 /*feature_group_count=*/3);
1314 }
1315
1316 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1317 iota_int_init_value(input_elems, 1);
1318 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1319 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1320
1321 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1322 iota_int_init_value(filter_elems, 1);
1323 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1324 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1325
1326 auto expected_r1 = LiteralUtil::CreateR1<T>(
1327 {static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
1328 static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
1329 static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
1330 static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
1331 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
1332
1333 auto input_literal =
1334 client_->TransferToServer(input_r4).ConsumeValueOrDie();
1335 auto filter_literal =
1336 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1337
1338 ComputeAndCompareLiteral(&builder, expected_r4,
1339 {input_literal.get(), filter_literal.get()},
1340 error_spec_);
1341 }
1342 };
1343
1344 TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid,Types)1345 TYPED_TEST(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, Types) {
1346 this->RunTest();
1347 }
1348
1349 template <typename T>
1350 class Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid : public ConvolutionTest {
1351 public:
RunTest()1352 void RunTest() {
1353 XlaBuilder builder(TestName());
1354 std::vector<int64> input_dims = {1, 2, 2, 1024};
1355 std::vector<int64> filter_dims = {2, 2, 128, 512};
1356 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1357 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1358 {
1359 auto input = Parameter(&builder, 0, input_shape, "input");
1360 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1361
1362 // Tensorflow dimension numbers for 2D convolution.
1363 ConvolutionDimensionNumbers dnums;
1364 dnums.set_input_batch_dimension(0);
1365 dnums.set_output_batch_dimension(0);
1366 dnums.add_input_spatial_dimensions(1);
1367 dnums.add_output_spatial_dimensions(1);
1368 dnums.add_input_spatial_dimensions(2);
1369 dnums.add_output_spatial_dimensions(2);
1370 dnums.set_input_feature_dimension(3);
1371 dnums.set_output_feature_dimension(3);
1372 dnums.add_kernel_spatial_dimensions(0);
1373 dnums.add_kernel_spatial_dimensions(1);
1374 dnums.set_kernel_input_feature_dimension(2);
1375 dnums.set_kernel_output_feature_dimension(3);
1376
1377 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1378 /*feature_group_count=*/8);
1379 }
1380
1381 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1382 static_cast<T>(1));
1383
1384 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1385 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1386
1387 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1388 static_cast<T>(2));
1389
1390 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1391 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1392
1393 std::vector<T> output_elems(512, static_cast<T>(1024));
1394 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1395 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 512}).ConsumeValueOrDie();
1396
1397 auto input_literal =
1398 client_->TransferToServer(input_r4).ConsumeValueOrDie();
1399 auto filter_literal =
1400 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1401
1402 ComputeAndCompareLiteral(&builder, expected_r4,
1403 {input_literal.get(), filter_literal.get()},
1404 error_spec_);
1405 }
1406 };
1407
1408 TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid,Types)1409 TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, Types) {
1410 this->RunTest();
1411 }
1412
1413 template <typename T>
1414 class Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid : public ConvolutionTest {
1415 public:
RunTest()1416 void RunTest() {
1417 XlaBuilder builder(TestName());
1418 std::vector<int64> input_dims = {1, 2, 2, 1024};
1419 std::vector<int64> filter_dims = {2, 2, 128, 8};
1420 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1421 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1422 {
1423 auto input = Parameter(&builder, 0, input_shape, "input");
1424 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1425
1426 // Tensorflow dimension numbers for 2D convolution.
1427 ConvolutionDimensionNumbers dnums;
1428 dnums.set_input_batch_dimension(0);
1429 dnums.set_output_batch_dimension(0);
1430 dnums.add_input_spatial_dimensions(1);
1431 dnums.add_output_spatial_dimensions(1);
1432 dnums.add_input_spatial_dimensions(2);
1433 dnums.add_output_spatial_dimensions(2);
1434 dnums.set_input_feature_dimension(3);
1435 dnums.set_output_feature_dimension(3);
1436 dnums.add_kernel_spatial_dimensions(0);
1437 dnums.add_kernel_spatial_dimensions(1);
1438 dnums.set_kernel_input_feature_dimension(2);
1439 dnums.set_kernel_output_feature_dimension(3);
1440
1441 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1442 /*feature_group_count=*/8);
1443 }
1444
1445 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1446 static_cast<T>(1));
1447
1448 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1449 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1450
1451 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1452 static_cast<T>(2));
1453
1454 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1455 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1456
1457 std::vector<T> output_elems(8, static_cast<T>(1024));
1458 auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1459 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 8}).ConsumeValueOrDie();
1460
1461 auto input_literal =
1462 client_->TransferToServer(input_r4).ConsumeValueOrDie();
1463 auto filter_literal =
1464 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1465
1466 ComputeAndCompareLiteral(&builder, expected_r4,
1467 {input_literal.get(), filter_literal.get()},
1468 error_spec_);
1469 }
1470 };
1471
1472 TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid,Types)1473 TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, Types) {
1474 this->RunTest();
1475 }
1476
1477 template <typename T>
1478 class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid : public ConvolutionTest {
1479 public:
RunTest()1480 void RunTest() {
1481 XlaBuilder builder(TestName());
1482 std::vector<int64> input_dims = {1, 2, 2, 12};
1483 std::vector<int64> filter_dims = {2, 2, 3, 4};
1484 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1485 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1486 {
1487 auto input = Parameter(&builder, 0, input_shape, "input");
1488 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1489
1490 // Tensorflow dimension numbers for 2D convolution.
1491 ConvolutionDimensionNumbers dnums;
1492 dnums.set_input_batch_dimension(0);
1493 dnums.set_output_batch_dimension(0);
1494 dnums.add_input_spatial_dimensions(1);
1495 dnums.add_output_spatial_dimensions(1);
1496 dnums.add_input_spatial_dimensions(2);
1497 dnums.add_output_spatial_dimensions(2);
1498 dnums.set_input_feature_dimension(3);
1499 dnums.set_output_feature_dimension(3);
1500 dnums.add_kernel_spatial_dimensions(0);
1501 dnums.add_kernel_spatial_dimensions(1);
1502 dnums.set_kernel_input_feature_dimension(2);
1503 dnums.set_kernel_output_feature_dimension(3);
1504
1505 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1506 /*feature_group_count=*/4);
1507 }
1508
1509 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1510 iota_int_init_value(input_elems, 1);
1511 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1512 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1513
1514 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1515 iota_int_init_value(filter_elems, 1);
1516 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1517 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1518
1519 auto expected_r1 =
1520 LiteralUtil::CreateR1<T>({static_cast<T>(7712), static_cast<T>(8816),
1521 static_cast<T>(9992), static_cast<T>(11240)});
1522 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie();
1523
1524 auto input_literal =
1525 client_->TransferToServer(input_r4).ConsumeValueOrDie();
1526 auto filter_literal =
1527 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1528
1529 ComputeAndCompareLiteral(&builder, expected_r4,
1530 {input_literal.get(), filter_literal.get()},
1531 error_spec_);
1532 }
1533 };
1534
1535 TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid,Types)1536 TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, Types) {
1537 this->RunTest();
1538 }
1539
1540 template <typename T>
1541 class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes
1542 : public ConvolutionTest {
1543 public:
RunTest()1544 void RunTest() {
1545 XlaBuilder builder(TestName());
1546 std::vector<int64> input_dims = {1, 2, 2, 12};
1547 std::vector<int64> filter_dims = {2, 2, 4, 3};
1548 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1549 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1550 {
1551 auto input = Parameter(&builder, 0, input_shape, "input");
1552 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1553
1554 // Tensorflow dimension numbers for 2D convolution.
1555 ConvolutionDimensionNumbers dnums;
1556 dnums.set_input_batch_dimension(0);
1557 dnums.set_output_batch_dimension(0);
1558 dnums.add_input_spatial_dimensions(1);
1559 dnums.add_output_spatial_dimensions(1);
1560 dnums.add_input_spatial_dimensions(2);
1561 dnums.add_output_spatial_dimensions(2);
1562 dnums.set_input_feature_dimension(3);
1563 dnums.set_output_feature_dimension(3);
1564 dnums.add_kernel_spatial_dimensions(0);
1565 dnums.add_kernel_spatial_dimensions(1);
1566 dnums.set_kernel_input_feature_dimension(3);
1567 dnums.set_kernel_output_feature_dimension(2);
1568
1569 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1570 /*feature_group_count=*/4);
1571 }
1572
1573 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1574 iota_int_init_value(input_elems, 1);
1575 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1576 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1577
1578 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1579 iota_int_init_value(filter_elems, 1);
1580 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1581 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1582 auto filter_r4_relaid =
1583 filter_r4.Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
1584 auto expected_r1 = LiteralUtil::CreateR1<T>(
1585 {static_cast<T>(6968), static_cast<T>(8516), static_cast<T>(10280),
1586 static_cast<T>(12260)});
1587 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie();
1588
1589 auto input_literal =
1590 client_->TransferToServer(input_r4).ConsumeValueOrDie();
1591 auto filter_literal =
1592 client_->TransferToServer(filter_r4_relaid).ConsumeValueOrDie();
1593
1594 ComputeAndCompareLiteral(&builder, expected_r4,
1595 {input_literal.get(), filter_literal.get()},
1596 error_spec_);
1597 }
1598 };
1599
// Registers the typed fixture against TestTypes; the single typed test simply
// delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes,
                TestTypes);
TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes,
           Types) {
  // `this->` is needed because RunTest() lives in a dependent base class.
  this->RunTest();
}
1606
1607 template <typename T>
1608 class Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid : public ConvolutionTest {
1609 public:
RunTest()1610 void RunTest() {
1611 XlaBuilder builder(TestName());
1612 std::vector<int64> input_dims = {1, 1, 1, 12};
1613 std::vector<int64> filter_dims = {1, 1, 3, 4};
1614 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1615 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1616 {
1617 auto input = Parameter(&builder, 0, input_shape, "input");
1618 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1619
1620 // Tensorflow dimension numbers for 2D convolution.
1621 ConvolutionDimensionNumbers dnums;
1622 dnums.set_input_batch_dimension(0);
1623 dnums.set_output_batch_dimension(0);
1624 dnums.add_input_spatial_dimensions(1);
1625 dnums.add_output_spatial_dimensions(1);
1626 dnums.add_input_spatial_dimensions(2);
1627 dnums.add_output_spatial_dimensions(2);
1628 dnums.set_input_feature_dimension(3);
1629 dnums.set_output_feature_dimension(3);
1630 dnums.add_kernel_spatial_dimensions(0);
1631 dnums.add_kernel_spatial_dimensions(1);
1632 dnums.set_kernel_input_feature_dimension(2);
1633 dnums.set_kernel_output_feature_dimension(3);
1634
1635 ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1636 /*feature_group_count=*/4);
1637 }
1638
1639 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1640 iota_int_init_value(input_elems, 1);
1641 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1642 auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1643
1644 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1645 iota_int_init_value(filter_elems, 1);
1646 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1647 auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1648
1649 auto expected_r1 =
1650 LiteralUtil::CreateR1<T>({static_cast<T>(38), static_cast<T>(98),
1651 static_cast<T>(176), static_cast<T>(272)});
1652 auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie();
1653
1654 auto input_literal =
1655 client_->TransferToServer(input_r4).ConsumeValueOrDie();
1656 auto filter_literal =
1657 client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1658
1659 ComputeAndCompareLiteral(&builder, expected_r4,
1660 {input_literal.get(), filter_literal.get()},
1661 error_spec_);
1662 }
1663 };
1664
// Registers the typed fixture against TestTypes; the single typed test simply
// delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, Types) {
  // `this->` is needed because RunTest() lives in a dependent base class.
  this->RunTest();
}
1669
// Test fixture to run convolution tests with and without convolution
// canonicalization enabled. The bool test parameter selects the mode: tests
// using this fixture disable the "convolution-canonicalization" HLO pass when
// the parameter is true and leave it enabled when it is false.
class ConvolveWithAndWithoutCanonicalization
    : public ConvolutionTest,
      public ::testing::WithParamInterface<bool> {};
1675
// Convolution of two rank-2 operands with no spatial dimensions: a {4, 29}
// input (feature, batch) and a {4, 10} filter (input feature, output
// feature), producing a (batch, feature) output. Runs with the
// "convolution-canonicalization" pass disabled when the test parameter is
// true, and with it enabled otherwise.
XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
           DISABLED_ON_GPU(Convolve2D_NoSpatialDims)) {
  if (GetParam()) {
    execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
        "convolution-canonicalization");
  }
  XlaBuilder builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29});
  Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10});

  auto input = Parameter(&builder, 0, input_shape, "input");
  auto filter = Parameter(&builder, 1, filter_shape, "filter");

  // No spatial dimensions are added, hence the empty window-strides list.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_feature_dimension(0);
  dnums.set_input_batch_dimension(1);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums);

  Array2D<float> param0(4, 29);
  param0.FillUnique();

  Array2D<float> param1(4, 10);
  param1.FillUnique();

  // ComputeAndCompare is given only the arguments and the error tolerance, not
  // an expected literal, so no hand-built expected array is needed here.
  ComputeAndCompare(&builder,
                    {LiteralUtil::CreateFromArray(param0),
                     LiteralUtil::CreateFromArray(param1)},
                    error_spec_);
}
1712
// Runs every ConvolveWithAndWithoutCanonicalization test once per bool
// parameter value (true and false).
INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation,
                        ConvolveWithAndWithoutCanonicalization,
                        ::testing::Values(true, false));
1716
// Parameters for the 1D convolution tests. Convolve1D1WindowTestBase builds an
// input of shape {batch, window_size + num_windows - 1, input_feature} and a
// filter of shape {window_size, input_feature, output_feature} from them.
// Brace initializers list the values in the field order declared below.
struct Convolve1DTestParam {
  // Size of the feature dimension of the input (and kernel input features).
  int64 input_feature;
  // Size of the kernel output-feature dimension.
  int64 output_feature;
  // Size of the batch dimension.
  int64 batch;
  // Extent of the kernel along its single spatial dimension.
  int64 window_size;
  // Size of the spatial dimension of the expected output.
  int64 num_windows;
};
1724
1725 class Convolve1D1WindowTestBase
1726 : public ConvolutionTest,
1727 public ::testing::WithParamInterface<Convolve1DTestParam> {
1728 protected:
1729 template <typename T>
TestImpl()1730 void TestImpl() {
1731 XlaBuilder builder(TestName());
1732 int64 input_feature = GetParam().input_feature;
1733 int64 output_feature = GetParam().output_feature;
1734 int64 batch = GetParam().batch;
1735 int64 num_windows = GetParam().num_windows;
1736 int64 window_size = GetParam().window_size;
1737 std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
1738 input_feature};
1739 std::vector<int64> filter_dims = {window_size, input_feature,
1740 output_feature};
1741 Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1742 Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1743 {
1744 auto input = Parameter(&builder, 0, input_shape, "input");
1745 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1746
1747 // Tensorflow dimension numbers for 1D convolution.
1748 ConvolutionDimensionNumbers dnums;
1749 dnums.set_input_batch_dimension(0);
1750 dnums.set_output_batch_dimension(0);
1751 dnums.add_input_spatial_dimensions(1);
1752 dnums.add_output_spatial_dimensions(1);
1753 dnums.set_input_feature_dimension(2);
1754 dnums.set_output_feature_dimension(2);
1755 dnums.add_kernel_spatial_dimensions(0);
1756 dnums.set_kernel_input_feature_dimension(1);
1757 dnums.set_kernel_output_feature_dimension(2);
1758
1759 ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums);
1760 }
1761
1762 std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1763 static_cast<T>(1.0f));
1764 auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1765 auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1766
1767 std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1768 static_cast<T>(1.0f));
1769
1770 auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1771 auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1772
1773 std::vector<T> expect_elems(batch * output_feature * num_windows,
1774 static_cast<T>(window_size * input_feature));
1775 auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
1776 auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
1777 .ConsumeValueOrDie();
1778
1779 auto input_literal =
1780 client_->TransferToServer(input_r3).ConsumeValueOrDie();
1781 auto filter_literal =
1782 client_->TransferToServer(filter_r3).ConsumeValueOrDie();
1783 ComputeAndCompareLiteral(&builder, expected_r3,
1784 {input_literal.get(), filter_literal.get()},
1785 error_spec_);
1786 }
1787 };
1788
1789 class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {};
1790
XLA_TEST_P(Convolve1D1WindowTestFloat,Convolve1D1Window)1791 XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl<float>(); }
1792
1793 INSTANTIATE_TEST_CASE_P(
1794 Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat,
1795 ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
1796 Convolve1DTestParam{160, 1, 1, 5, 1},
1797 Convolve1DTestParam{24, 1, 1, 20, 1},
1798 Convolve1DTestParam{30, 1, 1, 20, 1},
1799 Convolve1DTestParam{23, 1, 1, 20, 20},
1800 Convolve1DTestParam{25, 1, 1, 20, 1},
1801 Convolve1DTestParam{24, 1, 1, 10, 5},
1802 Convolve1DTestParam{160, 1, 1, 10, 1},
1803 Convolve1DTestParam{255, 1, 1, 3, 1},
1804 Convolve1DTestParam{130, 1, 1, 1, 3},
1805 Convolve1DTestParam{64, 1, 1, 1, 1},
1806 Convolve1DTestParam{128, 1, 1, 1, 1},
1807 Convolve1DTestParam{139, 1, 1, 128, 1},
1808 Convolve1DTestParam{1, 10, 10, 1, 10},
1809 Convolve1DTestParam{1, 10, 130, 1, 2},
1810 Convolve1DTestParam{1, 10, 130, 1, 1},
1811 Convolve1DTestParam{1, 64, 64, 1, 10},
1812 Convolve1DTestParam{1, 65, 65, 1, 1},
1813 Convolve1DTestParam{1, 128, 128, 1, 1},
1814 Convolve1DTestParam{128, 128, 128, 128, 1},
1815 Convolve1DTestParam{1, 128, 128, 1, 1},
1816 Convolve1DTestParam{2, 2, 2, 2, 1},
1817 Convolve1DTestParam{161, 1, 1, 10, 1},
1818 Convolve1DTestParam{900, 1, 1, 10, 1},
1819 Convolve1DTestParam{640, 3, 3, 128, 1})
1820
1821 );
1822
1823 #if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU)
1824 class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {};
1825
XLA_TEST_P(Convolve1D1WindowTestHalf,Convolve1D1Window)1826 XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) {
1827 TestImpl<Eigen::half>();
1828 }
1829
1830 INSTANTIATE_TEST_CASE_P(
1831 Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf,
1832 ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
1833 Convolve1DTestParam{160, 1, 1, 5, 1},
1834 Convolve1DTestParam{24, 1, 1, 20, 1},
1835 Convolve1DTestParam{30, 1, 1, 20, 1},
1836 Convolve1DTestParam{23, 1, 1, 20, 20},
1837 Convolve1DTestParam{25, 1, 1, 20, 1},
1838 Convolve1DTestParam{24, 1, 1, 10, 5},
1839 Convolve1DTestParam{160, 1, 1, 10, 1},
1840 Convolve1DTestParam{255, 1, 1, 3, 1},
1841 Convolve1DTestParam{130, 1, 1, 1, 3},
1842 Convolve1DTestParam{64, 1, 1, 1, 1},
1843 Convolve1DTestParam{128, 1, 1, 1, 1},
1844 // TODO(b/72566306): The following five tests failed on CPU with unreasonable
1845 // relative errors. Last ran on 2018-02-22.
1846 #if XLA_TEST_BACKEND_GPU
1847 Convolve1DTestParam{139, 1, 1, 128, 1},
1848 Convolve1DTestParam{640, 3, 3, 128, 1},
1849 Convolve1DTestParam{900, 1, 1, 10, 1},
1850 Convolve1DTestParam{1, 10, 10, 1, 10},
1851 Convolve1DTestParam{1, 10, 130, 1, 1},
1852 #endif
1853 Convolve1DTestParam{1, 10, 130, 1, 2},
1854 Convolve1DTestParam{1, 64, 64, 1, 10},
1855 Convolve1DTestParam{1, 65, 65, 1, 1},
1856 Convolve1DTestParam{1, 128, 128, 1, 1},
1857 Convolve1DTestParam{128, 128, 128, 128, 1},
1858 Convolve1DTestParam{1, 128, 128, 1, 1},
1859 Convolve1DTestParam{2, 2, 2, 2, 1},
1860 Convolve1DTestParam{161, 1, 1, 10, 1})
1861
1862 );
1863 #endif
1864
// BF16 convolution of a {1, 1, 1, 2} input with a {1, 1, 1, 2} filter using
// valid padding and default dimension numbers.
XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
  XlaBuilder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
  const Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto filter = Parameter(&builder, 1, filter_shape, "filter");
  Conv(input, filter, {1, 1}, Padding::kValid);

  // Input holds the row (1, 2).
  const Array2D<bfloat16> input_yx({
      {bfloat16(1), bfloat16(2)},
  });
  Array4D<bfloat16> input_values(1, 1, 1, 2);
  input_values.FillWithYX(input_yx);

  // Filter holds the row (5, 6).
  const Array2D<bfloat16> filter_yx({
      {bfloat16(5), bfloat16(6)},
  });
  Array4D<bfloat16> filter_values(1, 1, 1, 2);
  filter_values.FillWithYX(filter_yx);

  ComputeAndCompare(&builder,
                    {LiteralUtil::CreateFromArray(input_values),
                     LiteralUtil::CreateFromArray(filter_values)},
                    error_spec_);
}
1887
1888 // Check that GPU convs still work if the CudnnAlgorithmPicker pass is disabled.
1889 // (We run this test on all platforms, because, what the heck.)
XLA_TEST_F(ConvolutionTest,NoCudnnAlgorithmPicker)1890 XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
1891 execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
1892 "cudnn-conv-algorithm-picker");
1893
1894 XlaBuilder builder(TestName());
1895 Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1896 Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1897 auto input = Parameter(&builder, 0, input_shape, "input");
1898 auto filter = Parameter(&builder, 1, filter_shape, "filter");
1899 Conv(input, filter, {1, 1}, Padding::kValid);
1900
1901 Array4D<float> input_data(1, 1, 1, 2);
1902 input_data.FillIota(0);
1903 Array4D<float> filter_data(1, 1, 1, 2);
1904 filter_data.FillIota(10);
1905
1906 ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
1907 LiteralUtil::CreateFromArray(filter_data)});
1908 }
1909
// Grouped F32 convolution with one feature group per input feature
// (feature_group_count = 64): a {1, 64, 100, 100} input parameter convolved
// with a constant {7, 7, 1, 64} filter, unit strides and padding of 3 on
// every spatial edge.
XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
  XlaBuilder builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100});
  Array4D<float> input_data(1, 64, 100, 100);
  input_data.FillRandom(/*stddev=*/0.023, /*mean=*/0.001, /*seed=*/45321);
  // The filter must be randomized itself; filling `input_data` a second time
  // here would leave the filter untouched by FillRandom and make the
  // comparison far less meaningful.
  Array4D<float> filter_data(7, 7, 1, 64);
  filter_data.FillRandom(/*stddev=*/0.023, /*mean=*/0.001, /*seed=*/45320);
  auto input = Parameter(&builder, 0, input_shape, "input");
  // The filter is embedded as a constant, so only the input is a parameter.
  auto filter = ConstantR4FromArray4D(&builder, filter_data);

  // Specify bf01_01io->bf01 as dimension numbers.
  ConvolutionDimensionNumbers dnums;
  // Input
  dnums.set_input_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  // Kernel
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);
  // Output
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);
  ConvGeneral(input, filter, /*window_strides=*/{1, 1},
              /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
              /*feature_group_count=*/64);

  ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
                    error_spec_);
}
1945
// Fixture for convolution tests that are written as HLO text and executed
// through HloTestBase.
class ConvolutionHloTest : public HloTestBase {};
1947
// F64 convolution with a 3x3 window and f01b_i01o->01bf dimension labels,
// given as HLO text and checked by RunAndCompare with an ErrorSpec of 0.001.
XLA_TEST_F(ConvolutionHloTest, ConvolveF64Forward) {
  constexpr char kModuleText[] = R"(
HloModule TestModule

ENTRY Test {
  %arg0 = f64[3,56,56,16] parameter(0)
  %arg1 = f64[3,3,3,64] parameter(1)
  ROOT %conv = f64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
})";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(kModuleText, tolerance));
}
1959
// F32 convolution with a 3x3 window and rhs_reversal=1x1, given as HLO text
// and checked by RunAndCompare with an ErrorSpec of 0.001.
XLA_TEST_F(ConvolutionHloTest, ConvolveF32ForwardReversed) {
  constexpr char kModuleText[] = R"(
HloModule TestModule

ENTRY Test {
  %arg0 = f32[3,56,56,16] parameter(0)
  %arg1 = f32[3,3,3,32] parameter(1)
  ROOT %conv = f32[54,54,16,32] convolution(%arg0, %arg1), window={size=3x3 rhs_reversal=1x1}, dim_labels=f01b_i01o->01bf
})";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(kModuleText, tolerance));
}
1971
// F64 convolution with a 5x8 window and 1_2x1_2 padding, given as HLO text
// and checked by RunAndCompare with an ErrorSpec of 0.001.
XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardFilter) {
  constexpr char kModuleText[] = R"(
HloModule TestModule

ENTRY Test {
  %arg0 = f64[2,5,8,1] parameter(0)
  %arg1 = f64[2,5,8,2] parameter(1)
  ROOT %conv = f64[4,4,1,2] convolution(%arg0, %arg1), window={size=5x8 pad=1_2x1_2}, dim_labels=f01b_i01o->01bf
})";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(kModuleText, tolerance));
}
1983
// F64 convolution of an activation with a kernel that is first reversed along
// dimensions {2,3}, using a 7x7 window with 3_3x3_3 padding. Given as HLO text
// and checked by RunAndCompare with an ErrorSpec of 0.001.
XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardInput) {
  constexpr char kModuleText[] = R"(
HloModule TestModule

ENTRY Test {
  %output = f64[4,5,16,16] parameter(0)
  %kernel = f64[5,3,7,7] parameter(1)
  %reverse = f64[5,3,7,7] reverse(f64[5,3,7,7] %kernel), dimensions={2,3}
  ROOT %convolution = f64[4,3,16,16] convolution(%output, %reverse), window={size=7x7 pad=3_3x3_3}, dim_labels=bf01_io01->bf01
})";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(kModuleText, tolerance));
}
1996
1997 } // namespace
1998 } // namespace xla
1999