• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests of convolution with trivial kernels and no special variations (like
17 // strides and padding).
18 
#include <memory>
#include <numeric>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
41 
42 namespace xla {
43 namespace {
44 
45 class ConvolutionTest : public ClientLibraryTestBase {
46  protected:
47 #if XLA_TEST_BACKEND_GPU
48   // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial
49   // convolution. So relax the absolute error threshold.
50   ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-4);
51 #else
52   ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-4);
53 #endif
54 };
55 
56 #ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
57 using TestTypes = ::testing::Types<float>;
58 #else
59 using TestTypes = ::testing::Types<float, Eigen::half>;
60 #endif
61 
62 template <typename T>
63 class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
64  public:
RunTest()65   void RunTest() {
66     const int kInputActivationSizeY = 3;
67     const int kInputActivationSizeX = 3;
68     const int kInputActivationSizeZ = 256;
69     const int kKernelSizeX = 2;
70     const int kKernelSizeY = 2;
71     const int kOutputActivationSizeZ = 256;
72     const int kMiniBatchSize = 4;
73     auto alhs = absl::make_unique<Array4D<T>>(
74         kMiniBatchSize, kInputActivationSizeZ, kInputActivationSizeY,
75         kInputActivationSizeX);
76     alhs->FillWithMultiples(static_cast<T>(1.0f));
77     ASSERT_EQ(3, alhs->width());
78     ASSERT_EQ(3, alhs->height());
79 
80     auto arhs = absl::make_unique<Array4D<T>>(kOutputActivationSizeZ,
81                                               kInputActivationSizeZ,
82                                               kKernelSizeY, kKernelSizeX);
83     Array2D<T> rhs_raster({
84         {1.0f, 0.0f},  // row 0
85         {0.0f, 0.0f},  // row 1
86     });
87     arhs->FillWithYX(rhs_raster);
88     ASSERT_EQ(2, arhs->width());
89     ASSERT_EQ(2, arhs->height());
90 
91     XlaBuilder builder(TestName());
92     auto lhs = ConstantR4FromArray4D<T>(&builder, *alhs);
93     auto rhs = ConstantR4FromArray4D<T>(&builder, *arhs);
94     PrecisionConfig precision;
95     // The left hand side of the convolution is numbers between 0 and 2304 which
96     // requires at least 11 mantissa bits and the DEFAULT precision config is
97     // allowed to round to bfloat16 which only has 7 mantissa bits.
98     precision.add_operand_precision(PrecisionConfig::HIGHEST);
99     precision.add_operand_precision(PrecisionConfig::DEFAULT);
100     Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1,
101          /*batch_group_count=*/1, &precision);
102 
103     ComputeAndCompare(&builder, {}, error_spec_);
104   }
105 };
106 
107 TYPED_TEST_CASE(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, TestTypes);
XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota,Types)108 XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, Types) {
109   this->RunTest();
110 }
111 
112 template <typename T>
113 class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
114  public:
RunTest()115   void RunTest() {
116     XlaBuilder builder(TestName());
117     Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
118     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
119     auto input = Parameter(&builder, 0, input_shape, "input");
120     auto filter = Parameter(&builder, 1, filter_shape, "filter");
121     Conv(input, filter, {1, 1}, Padding::kValid);
122 
123     Array4D<T> input_data(1, 1, 1, 2);
124     input_data.FillWithYX(Array2D<T>({
125         {1.0f, 2.0f},
126     }));
127     Array4D<T> filter_data(1, 1, 1, 2);
128     filter_data.FillWithYX(Array2D<T>({
129         {5.0f, 6.0f},
130     }));
131 
132     ComputeAndCompare(&builder,
133                       {LiteralUtil::CreateFromArray(input_data),
134                        LiteralUtil::CreateFromArray(filter_data)},
135                       error_spec_);
136   }
137 };
138 
139 TYPED_TEST_CASE(Convolve_1x1x1x2_1x1x1x2_Valid, TestTypes);
TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid,Types)140 TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid, Types) { this->RunTest(); }
141 
142 // Tests valid padding for 2D convolution in raster space.
143 template <typename T>
144 class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
145  public:
RunTest()146   void RunTest() {
147     XlaBuilder builder(TestName());
148     Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
149     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
150     auto input = Parameter(&builder, 0, input_shape, "input");
151     auto filter = Parameter(&builder, 1, filter_shape, "filter");
152     Conv(input, filter, {1, 1}, Padding::kValid);
153 
154     Array4D<T> input_data(1, 1, 4, 4);
155     input_data.FillWithYX(Array2D<T>({
156         {1.0f, 2.0f, 3.0f, 4.0f},
157         {5.0f, 6.0f, 7.0f, 8.0f},
158         {9.0f, 10.0f, 11.0f, 12.0f},
159         {13.0f, 14.0f, 15.0f, 16.0f},
160     }));
161     Array4D<T> filter_data(1, 1, 2, 2);
162     filter_data.FillWithYX(Array2D<T>({
163         {5.0f, 6.0f},
164         {7.0f, 8.0f},
165     }));
166     ComputeAndCompare(&builder,
167                       {LiteralUtil::CreateFromArray(input_data),
168                        LiteralUtil::CreateFromArray(filter_data)},
169                       error_spec_);
170   }
171 };
172 
173 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Valid, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid,Types)174 TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid, Types) { this->RunTest(); }
175 
176 // Tests same padding for 2D convolution in raster space.
177 template <typename T>
178 class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
179  public:
RunTest()180   void RunTest() {
181     XlaBuilder builder(TestName());
182     Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
183     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
184     auto input = Parameter(&builder, 0, input_shape, "input");
185     auto filter = Parameter(&builder, 1, filter_shape, "filter");
186     Conv(input, filter, {1, 1}, Padding::kSame);
187 
188     Array4D<T> input_data(1, 1, 4, 4);
189     input_data.FillWithYX(Array2D<T>({
190         {1.0f, 2.0f, 3.0f, 4.0f},
191         {5.0f, 6.0f, 7.0f, 8.0f},
192         {9.0f, 10.0f, 11.0f, 12.0f},
193         {13.0f, 14.0f, 15.0f, 16.0f},
194     }));
195     Array4D<T> filter_data(1, 1, 2, 2);
196     filter_data.FillWithYX(Array2D<T>({
197         {5.0f, 6.0f},
198         {7.0f, 8.0f},
199     }));
200 
201     ComputeAndCompare(&builder,
202                       {LiteralUtil::CreateFromArray(input_data),
203                        LiteralUtil::CreateFromArray(filter_data)},
204                       error_spec_);
205   }
206 };
207 
208 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Same, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same,Types)209 TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same, Types) { this->RunTest(); }
210 
211 // Tests same padding for 2D convolution in raster space with an odd sized
212 // kernel.
213 template <typename T>
214 class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
215  public:
RunTest()216   void RunTest() {
217     XlaBuilder builder(TestName());
218     Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
219     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 3, 3});
220     auto input = Parameter(&builder, 0, input_shape, "input");
221     auto filter = Parameter(&builder, 1, filter_shape, "filter");
222     Conv(input, filter, {1, 1}, Padding::kSame);
223 
224     Array4D<T> input_data(1, 1, 4, 4);
225     input_data.FillWithYX(Array2D<T>({{1.0f, 2.0f, 3.0f, 4.0f},
226                                       {5.0f, 6.0f, 7.0f, 8.0f},
227                                       {9.0f, 10.0f, 11.0f, 12.0f},
228                                       {13.0f, 14.0f, 15.0f, 16.0f}}));
229     Array4D<T> filter_data(1, 1, 3, 3);
230     filter_data.FillWithYX(Array2D<T>(
231         {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
232     // clang-format on
233     ComputeAndCompare(&builder,
234                       {LiteralUtil::CreateFromArray(input_data),
235                        LiteralUtil::CreateFromArray(filter_data)},
236                       error_spec_);
237   }
238 };
239 
240 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same,Types)241 TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); }
242 
// 1D valid convolution of a 1x2x5 input with a 1x2x2 filter. Expected values
// were worked by hand, e.g. 510 = 1*10 + 2*20 + 6*30 + 7*40.
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
  XlaBuilder b(TestName());
  const XlaOp input_param =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 5}), "input");
  const XlaOp filter_param =
      Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2, 2}), "filter");
  Conv(input_param, filter_param, {1}, Padding::kValid);

  const Array3D<float> input_values({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  const Array3D<float> filter_values({{{10, 20}, {30, 40}}});
  const Array3D<float> expected({{{510, 610, 710, 810}}});

  auto input_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input_values))
          .ConsumeValueOrDie();
  auto filter_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter_values))
          .ConsumeValueOrDie();

  ComputeAndCompareR3<float>(&b, expected,
                             {input_data.get(), filter_data.get()},
                             error_spec_);
}
269 
270 template <typename T>
271 class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
272  public:
RunTest()273   void RunTest() {
274     XlaBuilder builder(TestName());
275     {
276       Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
277       Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
278       auto input = Parameter(&builder, 0, input_shape, "input");
279       auto filter = Parameter(&builder, 1, filter_shape, "filter");
280       // Convolution dimensions are bf0_oi0->bo0.
281       ConvGeneralDilated(
282           input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
283           /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
284           /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
285     }
286 
287     Array3D<T> input(
288         {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
289     Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
290 
291     Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
292 
293     auto input_literal =
294         client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
295             .ConsumeValueOrDie();
296     auto filter_literal =
297         client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
298             .ConsumeValueOrDie();
299 
300     ComputeAndCompareR3<T>(&builder, expected,
301                            {input_literal.get(), filter_literal.get()},
302                            error_spec_);
303   }
304 };  // namespace
305 
306 TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation,Types)307 TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
308 
// 1D convolution with LHS (input) dilation 2 and no padding: the 5-wide input
// is dilated to 9 wide, giving 8 outputs. Expected values were worked by hand,
// e.g. 190 = 1*10 + 6*30 and 320 = 2*20 + 7*40.
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
  XlaBuilder b(TestName());
  const XlaOp input_param =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 5}), "input");
  const XlaOp filter_param =
      Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2, 2}), "filter");
  // Convolution dimensions are bf0_oi0->bo0.
  ConvGeneralDilated(
      input_param, filter_param, /*window_strides=*/{1},
      /*padding=*/{{0, 0}}, /*lhs_dilation=*/{2}, /*rhs_dilation=*/{1},
      /*dimension_numbers=*/b.CreateDefaultConvDimensionNumbers(1));

  const Array3D<float> input_values({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  const Array3D<float> filter_values({{{10, 20}, {30, 40}}});
  const Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});

  auto input_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input_values))
          .ConsumeValueOrDie();
  auto filter_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter_values))
          .ConsumeValueOrDie();

  ComputeAndCompareR3<float>(&b, expected,
                             {input_data.get(), filter_data.get()},
                             error_spec_);
}
339 
// 1D convolution with both LHS and RHS dilation 2 and no padding. Even outputs
// see only original input values (e.g. 510 = 1*10 + 2*20 + 6*30 + 7*40); odd
// outputs see only the zeros inserted by LHS dilation.
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
  XlaBuilder b(TestName());
  const XlaOp input_param =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 5}), "input");
  const XlaOp filter_param =
      Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2, 2}), "filter");
  // Convolution dimensions are bf0_oi0->bo0.
  ConvGeneralDilated(
      input_param, filter_param, /*window_strides=*/{1},
      /*padding=*/{{0, 0}}, /*lhs_dilation=*/{2}, /*rhs_dilation=*/{2},
      /*dimension_numbers=*/b.CreateDefaultConvDimensionNumbers(1));

  const Array3D<float> input_values({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  const Array3D<float> filter_values({{{10, 20}, {30, 40}}});
  const Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});

  auto input_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input_values))
          .ConsumeValueOrDie();
  auto filter_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter_values))
          .ConsumeValueOrDie();

  ComputeAndCompareR3<float>(&b, expected,
                             {input_data.get(), filter_data.get()},
                             error_spec_);
}
370 
371 template <typename T>
372 class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
373  public:
RunTest()374   void RunTest() {
375     XlaBuilder builder(TestName());
376     {
377       Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
378       Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
379       auto input = Parameter(&builder, 0, input_shape, "input");
380       auto filter = Parameter(&builder, 1, filter_shape, "filter");
381       // Convolution dimensions are bf0_oi0->bo0.
382       ConvGeneralDilated(
383           input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
384           /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
385           /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
386     }
387 
388     Array3D<T> input(
389         {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
390     Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
391 
392     Array3D<T> expected(
393         {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
394 
395     auto input_literal =
396         client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
397             .ConsumeValueOrDie();
398     auto filter_literal =
399         client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
400             .ConsumeValueOrDie();
401 
402     ComputeAndCompareR3<T>(&builder, expected,
403                            {input_literal.get(), filter_literal.get()},
404                            error_spec_);
405   }
406 };
407 
408 TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding,Types)409 TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
410 
// 3D convolution in TensorFlow's NDHWC input / DHWIO kernel layout: a
// 1x4x2x3x3 input and a 2x2x2x3x3 filter with valid padding give a 1x3x1x2x3
// output. Both operands hold 1, 2, 3, ... and the expected values are
// precomputed.
XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
  XlaBuilder builder(TestName());
  std::vector<int64> input_dims = {1, 4, 2, 3, 3};
  std::vector<int64> filter_dims = {2, 2, 2, 3, 3};
  Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
  Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
  {
    auto input = Parameter(&builder, 0, input_shape, "input");
    auto filter = Parameter(&builder, 1, filter_shape, "filter");

    // Tensorflow dimension numbers for 3D convolution.
    ConvolutionDimensionNumbers dnums;
    dnums.set_input_batch_dimension(0);
    dnums.set_output_batch_dimension(0);
    dnums.add_input_spatial_dimensions(1);
    dnums.add_output_spatial_dimensions(1);
    dnums.add_input_spatial_dimensions(2);
    dnums.add_output_spatial_dimensions(2);
    dnums.add_input_spatial_dimensions(3);
    dnums.add_output_spatial_dimensions(3);
    dnums.set_input_feature_dimension(4);
    dnums.set_output_feature_dimension(4);
    dnums.add_kernel_spatial_dimensions(0);
    dnums.add_kernel_spatial_dimensions(1);
    dnums.add_kernel_spatial_dimensions(2);
    dnums.set_kernel_input_feature_dimension(3);
    dnums.set_kernel_output_feature_dimension(4);

    ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid, dnums);
  }

  // Fill both operands with 1, 2, 3, ... . std::iota is named explicitly
  // (from <numeric>) instead of relying on argument-dependent lookup.
  std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
  std::iota(input_elems.begin(), input_elems.end(), 1.0f);
  auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
  auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie();

  std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
  std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
  auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
  auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();

  auto expected_r1 = LiteralUtil::CreateR1<float>(
      {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
       38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
  auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();

  auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie();
  auto filter_literal =
      client_->TransferToServer(filter_r5).ConsumeValueOrDie();

  ComputeAndCompareLiteral(&builder, expected_r5,
                           {input_literal.get(), filter_literal.get()},
                           error_spec_);
}
465 
// Assigns the consecutive integers init_value, init_value + 1, ... to the
// elements of `values`, converting each one from int with static_cast<T>.
// std::iota is not used because it increments the T value itself, and with
// T = Eigen::half some build servers fail to find that operator++. A plain
// range-for is used so no extra algorithm header is required.
template <typename T>
void iota_int_init_value(std::vector<T>& values, int init_value) {
  for (T& value : values) {
    value = static_cast<T>(init_value++);
  }
}
473 
474 template <typename T>
475 class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
476  public:
RunTest()477   void RunTest() {
478     XlaBuilder builder(TestName());
479     std::vector<int64> input_dims = {1, 3, 3, 5};
480     std::vector<int64> filter_dims = {3, 3, 5, 3};
481     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
482     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
483     {
484       auto input = Parameter(&builder, 0, input_shape, "input");
485       auto filter = Parameter(&builder, 1, filter_shape, "filter");
486 
487       // Tensorflow dimension numbers for 2D convolution.
488       ConvolutionDimensionNumbers dnums;
489       dnums.set_input_batch_dimension(0);
490       dnums.set_output_batch_dimension(0);
491       dnums.add_input_spatial_dimensions(1);
492       dnums.add_output_spatial_dimensions(1);
493       dnums.add_input_spatial_dimensions(2);
494       dnums.add_output_spatial_dimensions(2);
495       dnums.set_input_feature_dimension(3);
496       dnums.set_output_feature_dimension(3);
497       dnums.add_kernel_spatial_dimensions(0);
498       dnums.add_kernel_spatial_dimensions(1);
499       dnums.set_kernel_input_feature_dimension(2);
500       dnums.set_kernel_output_feature_dimension(3);
501 
502       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums);
503     }
504 
505     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
506     iota_int_init_value(input_elems, 1);
507     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
508     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
509 
510     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
511     iota_int_init_value(filter_elems, 1);
512     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
513     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
514 
515     auto expected_r1 = LiteralUtil::CreateR1<T>(
516         {static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
517     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
518 
519     auto input_literal =
520         client_->TransferToServer(input_r4).ConsumeValueOrDie();
521     auto filter_literal =
522         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
523 
524     ComputeAndCompareLiteral(&builder, expected_r4,
525                              {input_literal.get(), filter_literal.get()},
526                              error_spec_);
527   }
528 };
529 
530 TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x3_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid,Types)531 TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid, Types) { this->RunTest(); }
532 
533 template <typename T>
534 class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
535  public:
RunTest()536   void RunTest() {
537     XlaBuilder builder(TestName());
538     std::vector<int64> input_dims = {1, 3, 3, 5};
539     std::vector<int64> filter_dims = {3, 3, 1, 15};
540     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
541     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
542     {
543       auto input = Parameter(&builder, 0, input_shape, "input");
544       auto filter = Parameter(&builder, 1, filter_shape, "filter");
545 
546       // Tensorflow dimension numbers for 2D convolution.
547       ConvolutionDimensionNumbers dnums;
548       dnums.set_input_batch_dimension(0);
549       dnums.set_output_batch_dimension(0);
550       dnums.add_input_spatial_dimensions(1);
551       dnums.add_output_spatial_dimensions(1);
552       dnums.add_input_spatial_dimensions(2);
553       dnums.add_output_spatial_dimensions(2);
554       dnums.set_input_feature_dimension(3);
555       dnums.set_output_feature_dimension(3);
556       dnums.add_kernel_spatial_dimensions(0);
557       dnums.add_kernel_spatial_dimensions(1);
558       dnums.set_kernel_input_feature_dimension(2);
559       dnums.set_kernel_output_feature_dimension(3);
560 
561       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
562                                 /*feature_group_count=*/5);
563     }
564 
565     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
566     iota_int_init_value(input_elems, 1);
567     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
568     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
569 
570     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
571     iota_int_init_value(filter_elems, 1);
572     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
573     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
574 
575     auto expected_r1 = LiteralUtil::CreateR1<T>(
576         {static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
577          static_cast<T>(17172), static_cast<T>(17370), static_cast<T>(17568),
578          static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
579          static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
580          static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
581     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
582 
583     auto input_literal =
584         client_->TransferToServer(input_r4).ConsumeValueOrDie();
585     auto filter_literal =
586         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
587 
588     ComputeAndCompareLiteral(&builder, expected_r4,
589                              {input_literal.get(), filter_literal.get()},
590                              error_spec_);
591   }
592 };
593 
594 TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid,Types)595 TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) {
596   this->RunTest();
597 }
598 
599 template <typename T>
600 class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid : public ConvolutionTest {
601  public:
RunTest()602   void RunTest() {
603     XlaBuilder builder(TestName());
604     std::vector<int64> input_dims = {1, 4, 4, 5};
605     std::vector<int64> filter_dims = {3, 3, 1, 5};
606     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
607     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
608     {
609       auto input = Parameter(&builder, 0, input_shape, "input");
610       auto filter = Parameter(&builder, 1, filter_shape, "filter");
611 
612       // Tensorflow dimension numbers for 2D convolution.
613       ConvolutionDimensionNumbers dnums;
614       dnums.set_input_batch_dimension(0);
615       dnums.set_output_batch_dimension(0);
616       dnums.add_input_spatial_dimensions(1);
617       dnums.add_output_spatial_dimensions(1);
618       dnums.add_input_spatial_dimensions(2);
619       dnums.add_output_spatial_dimensions(2);
620       dnums.set_input_feature_dimension(3);
621       dnums.set_output_feature_dimension(3);
622       dnums.add_kernel_spatial_dimensions(0);
623       dnums.add_kernel_spatial_dimensions(1);
624       dnums.set_kernel_input_feature_dimension(2);
625       dnums.set_kernel_output_feature_dimension(3);
626 
627       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
628                                 /*feature_group_count=*/5);
629     }
630 
631     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
632     iota_int_init_value(input_elems, 1);
633     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
634     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
635 
636     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
637     iota_int_init_value(filter_elems, 1);
638     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
639     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
640 
641     auto expected_r1 = LiteralUtil::CreateR1<T>(
642         {static_cast<T>(6864),  static_cast<T>(7296),  static_cast<T>(7746),
643          static_cast<T>(8214),  static_cast<T>(8700),  static_cast<T>(7809),
644          static_cast<T>(8286),  static_cast<T>(8781),  static_cast<T>(9294),
645          static_cast<T>(9825),  static_cast<T>(10644), static_cast<T>(11256),
646          static_cast<T>(11886), static_cast<T>(12534), static_cast<T>(13200),
647          static_cast<T>(11589), static_cast<T>(12246), static_cast<T>(12921),
648          static_cast<T>(13614), static_cast<T>(14325)});
649     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie();
650 
651     auto input_literal =
652         client_->TransferToServer(input_r4).ConsumeValueOrDie();
653     auto filter_literal =
654         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
655 
656     ComputeAndCompareLiteral(&builder, expected_r4,
657                              {input_literal.get(), filter_literal.get()},
658                              error_spec_);
659 
660     auto filter_r = filter_r1.Reshape(filter_dims);
661   }
662 };
663 
664 TYPED_TEST_CASE(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid,Types)665 TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, Types) {
666   this->RunTest();
667 }
668 
669 template <typename T>
670 class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid : public ConvolutionTest {
671  public:
RunTest()672   void RunTest() {
673     XlaBuilder builder(TestName());
674     std::vector<int64> input_dims = {1, 4, 4, 512};
675     std::vector<int64> filter_dims = {3, 3, 1, 512};
676     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
677     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
678     {
679       auto input = Parameter(&builder, 0, input_shape, "input");
680       auto filter = Parameter(&builder, 1, filter_shape, "filter");
681 
682       // Tensorflow dimension numbers for 2D convolution.
683       ConvolutionDimensionNumbers dnums;
684       dnums.set_input_batch_dimension(0);
685       dnums.set_output_batch_dimension(0);
686       dnums.add_input_spatial_dimensions(1);
687       dnums.add_output_spatial_dimensions(1);
688       dnums.add_input_spatial_dimensions(2);
689       dnums.add_output_spatial_dimensions(2);
690       dnums.set_input_feature_dimension(3);
691       dnums.set_output_feature_dimension(3);
692       dnums.add_kernel_spatial_dimensions(0);
693       dnums.add_kernel_spatial_dimensions(1);
694       dnums.set_kernel_input_feature_dimension(2);
695       dnums.set_kernel_output_feature_dimension(3);
696 
697       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
698                                 /*feature_group_count=*/512);
699     }
700 
701     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
702                                static_cast<T>(1));
703     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
704     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
705 
706     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
707                                 static_cast<T>(2));
708     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
709     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
710 
711     std::vector<T> output_elems(2048, static_cast<T>(18));
712 
713     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
714     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie();
715 
716     auto input_literal =
717         client_->TransferToServer(input_r4).ConsumeValueOrDie();
718     auto filter_literal =
719         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
720 
721     ComputeAndCompareLiteral(&builder, expected_r4,
722                              {input_literal.get(), filter_literal.get()},
723                              error_spec_);
724   }
725 };
726 
// Runs the fixture once per element type listed in TestTypes.
TYPED_TEST_CASE(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, Types) {
  // TestFixture is the fixture instantiated with this test's TypeParam.
  TestFixture::RunTest();
}
731 
732 template <typename T>
733 class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes
734     : public ConvolutionTest {
735  public:
RunTest()736   void RunTest() {
737     XlaBuilder builder(TestName());
738     std::vector<int64> input_dims = {1, 4, 4, 512};
739     std::vector<int64> filter_dims = {3, 3, 1, 512};
740     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
741     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
742     {
743       auto input = Parameter(&builder, 0, input_shape, "input");
744       auto filter = Parameter(&builder, 1, filter_shape, "filter");
745 
746       // Tensorflow dimension numbers for 2D convolution.
747       ConvolutionDimensionNumbers dnums;
748       dnums.set_input_batch_dimension(0);
749       dnums.set_output_batch_dimension(0);
750       dnums.add_input_spatial_dimensions(1);
751       dnums.add_output_spatial_dimensions(1);
752       dnums.add_input_spatial_dimensions(2);
753       dnums.add_output_spatial_dimensions(2);
754       dnums.set_input_feature_dimension(3);
755       dnums.set_output_feature_dimension(3);
756       dnums.add_kernel_spatial_dimensions(0);
757       dnums.add_kernel_spatial_dimensions(1);
758       dnums.set_kernel_input_feature_dimension(2);
759       dnums.set_kernel_output_feature_dimension(3);
760 
761       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
762                                 /*feature_group_count=*/512);
763     }
764 
765     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
766                                static_cast<T>(1));
767     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
768     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
769 
770     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
771                                 static_cast<T>(2));
772     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
773     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
774 
775     std::vector<T> output_elems(2048, static_cast<T>(18));
776 
777     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
778     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie();
779     auto expected_r4_relaid =
780         expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
781 
782     auto input_literal =
783         client_->TransferToServer(input_r4).ConsumeValueOrDie();
784     auto filter_literal =
785         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
786 
787     ComputeAndCompareLiteral(&builder, expected_r4_relaid,
788                              {input_literal.get(), filter_literal.get()},
789                              error_spec_, &expected_r4_relaid.shape());
790   }
791 };
792 
// Runs the fixture once per element type listed in TestTypes.
TYPED_TEST_CASE(
    Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes,
    TestTypes);
TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes,
           Types) {
  // TestFixture is the fixture instantiated with this test's TypeParam.
  TestFixture::RunTest();
}
800 
801 template <typename T>
802 class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes
803     : public ConvolutionTest {
804  public:
RunTest()805   void RunTest() {
806     XlaBuilder builder(TestName());
807     std::vector<int64> input_dims = {256, 4, 4, 512};
808     std::vector<int64> filter_dims = {3, 3, 1, 512};
809     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
810     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
811     {
812       auto input = Parameter(&builder, 0, input_shape, "input");
813       auto filter = Parameter(&builder, 1, filter_shape, "filter");
814 
815       // Tensorflow dimension numbers for 2D convolution.
816       ConvolutionDimensionNumbers dnums;
817       dnums.set_input_batch_dimension(0);
818       dnums.set_output_batch_dimension(0);
819       dnums.add_input_spatial_dimensions(1);
820       dnums.add_output_spatial_dimensions(1);
821       dnums.add_input_spatial_dimensions(2);
822       dnums.add_output_spatial_dimensions(2);
823       dnums.set_input_feature_dimension(3);
824       dnums.set_output_feature_dimension(3);
825       dnums.add_kernel_spatial_dimensions(0);
826       dnums.add_kernel_spatial_dimensions(1);
827       dnums.set_kernel_input_feature_dimension(2);
828       dnums.set_kernel_output_feature_dimension(3);
829 
830       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
831                                 /*feature_group_count=*/512);
832     }
833 
834     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
835                                static_cast<T>(1));
836     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
837     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
838     auto input_r4_relaid =
839         input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
840 
841     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
842                                 static_cast<T>(2));
843     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
844     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
845 
846     std::vector<T> output_elems(2048 * 256, static_cast<T>(18));
847 
848     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
849     auto expected_r4 =
850         expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie();
851 
852     auto input_literal =
853         client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie();
854     auto filter_literal =
855         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
856 
857     ComputeAndCompareLiteral(&builder, expected_r4,
858                              {input_literal.get(), filter_literal.get()},
859                              error_spec_);
860   }
861 };
862 
// Runs the fixture once per element type listed in TestTypes.
TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes,
                TestTypes);
TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes,
           Types) {
  // TestFixture is the fixture instantiated with this test's TypeParam.
  TestFixture::RunTest();
}
869 
870 template <typename T>
871 class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes
872     : public ConvolutionTest {
873  public:
RunTest()874   void RunTest() {
875     XlaBuilder builder(TestName());
876     std::vector<int64> input_dims = {256, 4, 4, 512};
877     std::vector<int64> filter_dims = {3, 3, 1, 512};
878     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
879     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
880     {
881       auto input = Parameter(&builder, 0, input_shape, "input");
882       auto filter = Parameter(&builder, 1, filter_shape, "filter");
883 
884       // Tensorflow dimension numbers for 2D convolution.
885       ConvolutionDimensionNumbers dnums;
886       dnums.set_input_batch_dimension(0);
887       dnums.set_output_batch_dimension(0);
888       dnums.add_input_spatial_dimensions(1);
889       dnums.add_output_spatial_dimensions(1);
890       dnums.add_input_spatial_dimensions(2);
891       dnums.add_output_spatial_dimensions(2);
892       dnums.set_input_feature_dimension(3);
893       dnums.set_output_feature_dimension(3);
894       dnums.add_kernel_spatial_dimensions(0);
895       dnums.add_kernel_spatial_dimensions(1);
896       dnums.set_kernel_input_feature_dimension(2);
897       dnums.set_kernel_output_feature_dimension(3);
898 
899       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
900                                 /*feature_group_count=*/512);
901     }
902 
903     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
904                                static_cast<T>(1));
905     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
906     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
907     auto input_r4_relaid =
908         input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
909 
910     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
911                                 static_cast<T>(2));
912     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
913     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
914 
915     std::vector<T> output_elems(2048 * 256, static_cast<T>(18));
916 
917     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
918     auto expected_r4 =
919         expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie();
920     auto expected_r4_relaid =
921         expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
922 
923     auto input_literal =
924         client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie();
925     auto filter_literal =
926         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
927 
928     ComputeAndCompareLiteral(&builder, expected_r4_relaid,
929                              {input_literal.get(), filter_literal.get()},
930                              error_spec_, &expected_r4_relaid.shape());
931   }
932 };
933 
// Runs the fixture once per element type listed in TestTypes.
TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes,
                TestTypes);
TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes,
           Types) {
  // TestFixture is the fixture instantiated with this test's TypeParam.
  TestFixture::RunTest();
}
940 
941 template <typename T>
942 class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes
943     : public ConvolutionTest {
944  public:
RunTest()945   void RunTest() {
946     XlaBuilder builder(TestName());
947     std::vector<int64> input_dims = {1, 4, 4, 5};
948     std::vector<int64> filter_dims = {3, 3, 1, 5};
949     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
950     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
951     {
952       auto input = Parameter(&builder, 0, input_shape, "input");
953       auto filter = Parameter(&builder, 1, filter_shape, "filter");
954 
955       // Tensorflow dimension numbers for 2D convolution.
956       ConvolutionDimensionNumbers dnums;
957       dnums.set_input_batch_dimension(0);
958       dnums.set_output_batch_dimension(0);
959       dnums.add_input_spatial_dimensions(1);
960       dnums.add_output_spatial_dimensions(1);
961       dnums.add_input_spatial_dimensions(2);
962       dnums.add_output_spatial_dimensions(2);
963       dnums.set_input_feature_dimension(3);
964       dnums.set_output_feature_dimension(3);
965       dnums.add_kernel_spatial_dimensions(0);
966       dnums.add_kernel_spatial_dimensions(1);
967       dnums.set_kernel_input_feature_dimension(2);
968       dnums.set_kernel_output_feature_dimension(3);
969 
970       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
971                                 /*feature_group_count=*/5);
972     }
973 
974     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
975     iota_int_init_value(input_elems, 1);
976     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
977     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
978     auto input_r4_relaid =
979         input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
980 
981     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
982     iota_int_init_value(filter_elems, 1);
983     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
984     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
985 
986     auto expected_r1 = LiteralUtil::CreateR1<T>(
987         {static_cast<T>(6864),  static_cast<T>(7296),  static_cast<T>(7746),
988          static_cast<T>(8214),  static_cast<T>(8700),  static_cast<T>(7809),
989          static_cast<T>(8286),  static_cast<T>(8781),  static_cast<T>(9294),
990          static_cast<T>(9825),  static_cast<T>(10644), static_cast<T>(11256),
991          static_cast<T>(11886), static_cast<T>(12534), static_cast<T>(13200),
992          static_cast<T>(11589), static_cast<T>(12246), static_cast<T>(12921),
993          static_cast<T>(13614), static_cast<T>(14325)});
994     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie();
995     auto expected_r4_relaid =
996         expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
997 
998     auto input_literal =
999         client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie();
1000     auto filter_literal =
1001         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1002 
1003     ComputeAndCompareLiteral(&builder, expected_r4_relaid,
1004                              {input_literal.get(), filter_literal.get()},
1005                              error_spec_, &expected_r4_relaid.shape());
1006   }
1007 };
1008 
// Runs the fixture once per element type listed in TestTypes.
TYPED_TEST_CASE(
    Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes,
    TestTypes);
TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes,
           Types) {
  // TestFixture is the fixture instantiated with this test's TypeParam.
  TestFixture::RunTest();
}
1016 
1017 template <typename T>
1018 class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest {
1019  public:
RunTest()1020   void RunTest() {
1021     XlaBuilder builder(TestName());
1022     std::vector<int64> input_dims = {1, 4, 4, 160};
1023     std::vector<int64> filter_dims = {3, 3, 1, 160};
1024     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1025     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1026     {
1027       auto input = Parameter(&builder, 0, input_shape, "input");
1028       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1029 
1030       // Tensorflow dimension numbers for 2D convolution.
1031       ConvolutionDimensionNumbers dnums;
1032       dnums.set_input_batch_dimension(0);
1033       dnums.set_output_batch_dimension(0);
1034       dnums.add_input_spatial_dimensions(1);
1035       dnums.add_output_spatial_dimensions(1);
1036       dnums.add_input_spatial_dimensions(2);
1037       dnums.add_output_spatial_dimensions(2);
1038       dnums.set_input_feature_dimension(3);
1039       dnums.set_output_feature_dimension(3);
1040       dnums.add_kernel_spatial_dimensions(0);
1041       dnums.add_kernel_spatial_dimensions(1);
1042       dnums.set_kernel_input_feature_dimension(2);
1043       dnums.set_kernel_output_feature_dimension(3);
1044 
1045       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1046                                 /*feature_group_count=*/160);
1047     }
1048 
1049     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1050                                static_cast<T>(1));
1051     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1052     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1053 
1054     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1055                                 static_cast<T>(2));
1056     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1057     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1058 
1059     std::vector<T> output_elems(640, static_cast<T>(18));
1060 
1061     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1062     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie();
1063 
1064     auto input_literal =
1065         client_->TransferToServer(input_r4).ConsumeValueOrDie();
1066     auto filter_literal =
1067         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1068 
1069     ComputeAndCompareLiteral(&builder, expected_r4,
1070                              {input_literal.get(), filter_literal.get()},
1071                              error_spec_);
1072   }
1073 };
1074 
// Runs the fixture once per element type listed in TestTypes.
TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, Types) {
  // TestFixture is the fixture instantiated with this test's TypeParam.
  TestFixture::RunTest();
}
1079 
1080 template <typename T>
1081 class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes
1082     : public ConvolutionTest {
1083  public:
RunTest()1084   void RunTest() {
1085     XlaBuilder builder(TestName());
1086     std::vector<int64> input_dims = {1, 4, 4, 160};
1087     std::vector<int64> filter_dims = {3, 3, 1, 160};
1088     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1089     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1090     {
1091       auto input = Parameter(&builder, 0, input_shape, "input");
1092       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1093 
1094       // Tensorflow dimension numbers for 2D convolution.
1095       ConvolutionDimensionNumbers dnums;
1096       dnums.set_input_batch_dimension(0);
1097       dnums.set_output_batch_dimension(0);
1098       dnums.add_input_spatial_dimensions(1);
1099       dnums.add_output_spatial_dimensions(1);
1100       dnums.add_input_spatial_dimensions(2);
1101       dnums.add_output_spatial_dimensions(2);
1102       dnums.set_input_feature_dimension(3);
1103       dnums.set_output_feature_dimension(3);
1104       dnums.add_kernel_spatial_dimensions(0);
1105       dnums.add_kernel_spatial_dimensions(1);
1106       dnums.set_kernel_input_feature_dimension(2);
1107       dnums.set_kernel_output_feature_dimension(3);
1108 
1109       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1110                                 /*feature_group_count=*/160);
1111     }
1112 
1113     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1114                                static_cast<T>(1));
1115     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1116     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1117     auto input_r4_relaid =
1118         input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
1119 
1120     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1121                                 static_cast<T>(2));
1122     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1123     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1124 
1125     std::vector<T> output_elems(640, static_cast<T>(18));
1126 
1127     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1128     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie();
1129     auto expected_r4_relaid =
1130         expected_r4.Relayout(LayoutUtil::MakeLayout({3, 0, 2, 1}));
1131 
1132     auto input_literal =
1133         client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie();
1134     auto filter_literal =
1135         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1136 
1137     ComputeAndCompareLiteral(&builder, expected_r4_relaid,
1138                              {input_literal.get(), filter_literal.get()},
1139                              error_spec_, &expected_r4_relaid.shape());
1140   }
1141 };
1142 
// Runs the fixture once per element type listed in TestTypes.
TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes,
                TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes,
           Types) {
  // TestFixture is the fixture instantiated with this test's TypeParam.
  TestFixture::RunTest();
}
1149 
1150 template <typename T>
1151 class Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes
1152     : public ConvolutionTest {
1153  public:
RunTest()1154   void RunTest() {
1155     XlaBuilder builder(TestName());
1156     std::vector<int64> input_dims = {1, 4, 4, 160};
1157     std::vector<int64> filter_dims = {3, 3, 1, 160};
1158     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1159     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1160     {
1161       auto input = Parameter(&builder, 0, input_shape, "input");
1162       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1163 
1164       // Tensorflow dimension numbers for 2D convolution.
1165       ConvolutionDimensionNumbers dnums;
1166       dnums.set_input_batch_dimension(0);
1167       dnums.set_output_batch_dimension(0);
1168       dnums.add_input_spatial_dimensions(1);
1169       dnums.add_output_spatial_dimensions(1);
1170       dnums.add_input_spatial_dimensions(2);
1171       dnums.add_output_spatial_dimensions(2);
1172       dnums.set_input_feature_dimension(3);
1173       dnums.set_output_feature_dimension(3);
1174       dnums.add_kernel_spatial_dimensions(0);
1175       dnums.add_kernel_spatial_dimensions(1);
1176       dnums.set_kernel_input_feature_dimension(2);
1177       dnums.set_kernel_output_feature_dimension(3);
1178 
1179       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1180                                 /*feature_group_count=*/160);
1181     }
1182 
1183     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1184                                static_cast<T>(1));
1185     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1186     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1187     auto input_r4_relaid =
1188         input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
1189 
1190     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1191                                 static_cast<T>(2));
1192     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1193     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1194 
1195     std::vector<T> output_elems(640, static_cast<T>(18));
1196 
1197     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1198     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie();
1199     auto expected_r4_relaid =
1200         expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
1201 
1202     auto input_literal =
1203         client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie();
1204     auto filter_literal =
1205         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1206 
1207     ComputeAndCompareLiteral(&builder, expected_r4_relaid,
1208                              {input_literal.get(), filter_literal.get()},
1209                              error_spec_, &expected_r4_relaid.shape());
1210   }
1211 };
1212 
// Runs the fixture once per element type listed in TestTypes.
TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes,
                TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes,
           Types) {
  // TestFixture is the fixture instantiated with this test's TypeParam.
  TestFixture::RunTest();
}
1219 
1220 template <typename T>
1221 class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid
1222     : public ConvolutionTest {
1223  public:
RunTest()1224   void RunTest() {
1225     XlaBuilder builder(TestName());
1226     std::vector<int64> input_dims = {1, 4, 4, 1024};
1227     std::vector<int64> filter_dims = {3, 3, 1, 1024};
1228     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1229     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1230     {
1231       auto input = Parameter(&builder, 0, input_shape, "input");
1232       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1233 
1234       // Tensorflow dimension numbers for 2D convolution.
1235       ConvolutionDimensionNumbers dnums;
1236       dnums.set_input_batch_dimension(0);
1237       dnums.set_output_batch_dimension(0);
1238       dnums.add_input_spatial_dimensions(1);
1239       dnums.add_output_spatial_dimensions(1);
1240       dnums.add_input_spatial_dimensions(2);
1241       dnums.add_output_spatial_dimensions(2);
1242       dnums.set_input_feature_dimension(3);
1243       dnums.set_output_feature_dimension(3);
1244       dnums.add_kernel_spatial_dimensions(0);
1245       dnums.add_kernel_spatial_dimensions(1);
1246       dnums.set_kernel_input_feature_dimension(2);
1247       dnums.set_kernel_output_feature_dimension(3);
1248 
1249       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1250                                 /*feature_group_count=*/1024);
1251     }
1252 
1253     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1254                                static_cast<T>(1));
1255     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1256     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1257 
1258     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1259                                 static_cast<T>(2));
1260     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1261     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1262 
1263     std::vector<T> output_elems(4096, static_cast<T>(18));
1264 
1265     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1266     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 1024}).ConsumeValueOrDie();
1267 
1268     auto input_literal =
1269         client_->TransferToServer(input_r4).ConsumeValueOrDie();
1270     auto filter_literal =
1271         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1272 
1273     ComputeAndCompareLiteral(&builder, expected_r4,
1274                              {input_literal.get(), filter_literal.get()},
1275                              error_spec_);
1276   }
1277 };
1278 
// Runs the fixture once per element type listed in TestTypes.
TYPED_TEST_CASE(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, Types) {
  // TestFixture is the fixture instantiated with this test's TypeParam.
  TestFixture::RunTest();
}
1283 
1284 template <typename T>
1285 class Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid : public ConvolutionTest {
1286  public:
RunTest()1287   void RunTest() {
1288     XlaBuilder builder(TestName());
1289     std::vector<int64> input_dims = {1, 2, 2, 6};
1290     std::vector<int64> filter_dims = {2, 2, 2, 12};
1291     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1292     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1293     {
1294       auto input = Parameter(&builder, 0, input_shape, "input");
1295       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1296 
1297       // Tensorflow dimension numbers for 2D convolution.
1298       ConvolutionDimensionNumbers dnums;
1299       dnums.set_input_batch_dimension(0);
1300       dnums.set_output_batch_dimension(0);
1301       dnums.add_input_spatial_dimensions(1);
1302       dnums.add_output_spatial_dimensions(1);
1303       dnums.add_input_spatial_dimensions(2);
1304       dnums.add_output_spatial_dimensions(2);
1305       dnums.set_input_feature_dimension(3);
1306       dnums.set_output_feature_dimension(3);
1307       dnums.add_kernel_spatial_dimensions(0);
1308       dnums.add_kernel_spatial_dimensions(1);
1309       dnums.set_kernel_input_feature_dimension(2);
1310       dnums.set_kernel_output_feature_dimension(3);
1311 
1312       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1313                                 /*feature_group_count=*/3);
1314     }
1315 
1316     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1317     iota_int_init_value(input_elems, 1);
1318     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1319     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1320 
1321     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1322     iota_int_init_value(filter_elems, 1);
1323     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1324     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1325 
1326     auto expected_r1 = LiteralUtil::CreateR1<T>(
1327         {static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
1328          static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
1329          static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
1330          static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
1331     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
1332 
1333     auto input_literal =
1334         client_->TransferToServer(input_r4).ConsumeValueOrDie();
1335     auto filter_literal =
1336         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1337 
1338     ComputeAndCompareLiteral(&builder, expected_r4,
1339                              {input_literal.get(), filter_literal.get()},
1340                              error_spec_);
1341   }
1342 };
1343 
// Runs the fixture once per element type listed in TestTypes.
TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, Types) {
  // TestFixture is the fixture instantiated with this test's TypeParam.
  TestFixture::RunTest();
}
1348 
1349 template <typename T>
1350 class Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid : public ConvolutionTest {
1351  public:
RunTest()1352   void RunTest() {
1353     XlaBuilder builder(TestName());
1354     std::vector<int64> input_dims = {1, 2, 2, 1024};
1355     std::vector<int64> filter_dims = {2, 2, 128, 512};
1356     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1357     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1358     {
1359       auto input = Parameter(&builder, 0, input_shape, "input");
1360       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1361 
1362       // Tensorflow dimension numbers for 2D convolution.
1363       ConvolutionDimensionNumbers dnums;
1364       dnums.set_input_batch_dimension(0);
1365       dnums.set_output_batch_dimension(0);
1366       dnums.add_input_spatial_dimensions(1);
1367       dnums.add_output_spatial_dimensions(1);
1368       dnums.add_input_spatial_dimensions(2);
1369       dnums.add_output_spatial_dimensions(2);
1370       dnums.set_input_feature_dimension(3);
1371       dnums.set_output_feature_dimension(3);
1372       dnums.add_kernel_spatial_dimensions(0);
1373       dnums.add_kernel_spatial_dimensions(1);
1374       dnums.set_kernel_input_feature_dimension(2);
1375       dnums.set_kernel_output_feature_dimension(3);
1376 
1377       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1378                                 /*feature_group_count=*/8);
1379     }
1380 
1381     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1382                                static_cast<T>(1));
1383 
1384     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1385     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1386 
1387     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1388                                 static_cast<T>(2));
1389 
1390     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1391     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1392 
1393     std::vector<T> output_elems(512, static_cast<T>(1024));
1394     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1395     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 512}).ConsumeValueOrDie();
1396 
1397     auto input_literal =
1398         client_->TransferToServer(input_r4).ConsumeValueOrDie();
1399     auto filter_literal =
1400         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1401 
1402     ComputeAndCompareLiteral(&builder, expected_r4,
1403                              {input_literal.get(), filter_literal.get()},
1404                              error_spec_);
1405   }
1406 };
1407 
1408 TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid,Types)1409 TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, Types) {
1410   this->RunTest();
1411 }
1412 
1413 template <typename T>
1414 class Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid : public ConvolutionTest {
1415  public:
RunTest()1416   void RunTest() {
1417     XlaBuilder builder(TestName());
1418     std::vector<int64> input_dims = {1, 2, 2, 1024};
1419     std::vector<int64> filter_dims = {2, 2, 128, 8};
1420     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1421     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1422     {
1423       auto input = Parameter(&builder, 0, input_shape, "input");
1424       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1425 
1426       // Tensorflow dimension numbers for 2D convolution.
1427       ConvolutionDimensionNumbers dnums;
1428       dnums.set_input_batch_dimension(0);
1429       dnums.set_output_batch_dimension(0);
1430       dnums.add_input_spatial_dimensions(1);
1431       dnums.add_output_spatial_dimensions(1);
1432       dnums.add_input_spatial_dimensions(2);
1433       dnums.add_output_spatial_dimensions(2);
1434       dnums.set_input_feature_dimension(3);
1435       dnums.set_output_feature_dimension(3);
1436       dnums.add_kernel_spatial_dimensions(0);
1437       dnums.add_kernel_spatial_dimensions(1);
1438       dnums.set_kernel_input_feature_dimension(2);
1439       dnums.set_kernel_output_feature_dimension(3);
1440 
1441       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1442                                 /*feature_group_count=*/8);
1443     }
1444 
1445     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1446                                static_cast<T>(1));
1447 
1448     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1449     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1450 
1451     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1452                                 static_cast<T>(2));
1453 
1454     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1455     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1456 
1457     std::vector<T> output_elems(8, static_cast<T>(1024));
1458     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1459     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 8}).ConsumeValueOrDie();
1460 
1461     auto input_literal =
1462         client_->TransferToServer(input_r4).ConsumeValueOrDie();
1463     auto filter_literal =
1464         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1465 
1466     ComputeAndCompareLiteral(&builder, expected_r4,
1467                              {input_literal.get(), filter_literal.get()},
1468                              error_spec_);
1469   }
1470 };
1471 
1472 TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid,Types)1473 TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, Types) {
1474   this->RunTest();
1475 }
1476 
1477 template <typename T>
1478 class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid : public ConvolutionTest {
1479  public:
RunTest()1480   void RunTest() {
1481     XlaBuilder builder(TestName());
1482     std::vector<int64> input_dims = {1, 2, 2, 12};
1483     std::vector<int64> filter_dims = {2, 2, 3, 4};
1484     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1485     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1486     {
1487       auto input = Parameter(&builder, 0, input_shape, "input");
1488       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1489 
1490       // Tensorflow dimension numbers for 2D convolution.
1491       ConvolutionDimensionNumbers dnums;
1492       dnums.set_input_batch_dimension(0);
1493       dnums.set_output_batch_dimension(0);
1494       dnums.add_input_spatial_dimensions(1);
1495       dnums.add_output_spatial_dimensions(1);
1496       dnums.add_input_spatial_dimensions(2);
1497       dnums.add_output_spatial_dimensions(2);
1498       dnums.set_input_feature_dimension(3);
1499       dnums.set_output_feature_dimension(3);
1500       dnums.add_kernel_spatial_dimensions(0);
1501       dnums.add_kernel_spatial_dimensions(1);
1502       dnums.set_kernel_input_feature_dimension(2);
1503       dnums.set_kernel_output_feature_dimension(3);
1504 
1505       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1506                                 /*feature_group_count=*/4);
1507     }
1508 
1509     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1510     iota_int_init_value(input_elems, 1);
1511     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1512     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1513 
1514     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1515     iota_int_init_value(filter_elems, 1);
1516     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1517     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1518 
1519     auto expected_r1 =
1520         LiteralUtil::CreateR1<T>({static_cast<T>(7712), static_cast<T>(8816),
1521                                   static_cast<T>(9992), static_cast<T>(11240)});
1522     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie();
1523 
1524     auto input_literal =
1525         client_->TransferToServer(input_r4).ConsumeValueOrDie();
1526     auto filter_literal =
1527         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1528 
1529     ComputeAndCompareLiteral(&builder, expected_r4,
1530                              {input_literal.get(), filter_literal.get()},
1531                              error_spec_);
1532   }
1533 };
1534 
1535 TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid,Types)1536 TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, Types) {
1537   this->RunTest();
1538 }
1539 
1540 template <typename T>
1541 class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes
1542     : public ConvolutionTest {
1543  public:
RunTest()1544   void RunTest() {
1545     XlaBuilder builder(TestName());
1546     std::vector<int64> input_dims = {1, 2, 2, 12};
1547     std::vector<int64> filter_dims = {2, 2, 4, 3};
1548     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1549     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1550     {
1551       auto input = Parameter(&builder, 0, input_shape, "input");
1552       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1553 
1554       // Tensorflow dimension numbers for 2D convolution.
1555       ConvolutionDimensionNumbers dnums;
1556       dnums.set_input_batch_dimension(0);
1557       dnums.set_output_batch_dimension(0);
1558       dnums.add_input_spatial_dimensions(1);
1559       dnums.add_output_spatial_dimensions(1);
1560       dnums.add_input_spatial_dimensions(2);
1561       dnums.add_output_spatial_dimensions(2);
1562       dnums.set_input_feature_dimension(3);
1563       dnums.set_output_feature_dimension(3);
1564       dnums.add_kernel_spatial_dimensions(0);
1565       dnums.add_kernel_spatial_dimensions(1);
1566       dnums.set_kernel_input_feature_dimension(3);
1567       dnums.set_kernel_output_feature_dimension(2);
1568 
1569       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1570                                 /*feature_group_count=*/4);
1571     }
1572 
1573     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1574     iota_int_init_value(input_elems, 1);
1575     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1576     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1577 
1578     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1579     iota_int_init_value(filter_elems, 1);
1580     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1581     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1582     auto filter_r4_relaid =
1583         filter_r4.Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
1584     auto expected_r1 = LiteralUtil::CreateR1<T>(
1585         {static_cast<T>(6968), static_cast<T>(8516), static_cast<T>(10280),
1586          static_cast<T>(12260)});
1587     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie();
1588 
1589     auto input_literal =
1590         client_->TransferToServer(input_r4).ConsumeValueOrDie();
1591     auto filter_literal =
1592         client_->TransferToServer(filter_r4_relaid).ConsumeValueOrDie();
1593 
1594     ComputeAndCompareLiteral(&builder, expected_r4,
1595                              {input_literal.get(), filter_literal.get()},
1596                              error_spec_);
1597   }
1598 };
1599 
1600 TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes,
1601                 TestTypes);
TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes,Types)1602 TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes,
1603            Types) {
1604   this->RunTest();
1605 }
1606 
1607 template <typename T>
1608 class Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid : public ConvolutionTest {
1609  public:
RunTest()1610   void RunTest() {
1611     XlaBuilder builder(TestName());
1612     std::vector<int64> input_dims = {1, 1, 1, 12};
1613     std::vector<int64> filter_dims = {1, 1, 3, 4};
1614     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1615     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1616     {
1617       auto input = Parameter(&builder, 0, input_shape, "input");
1618       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1619 
1620       // Tensorflow dimension numbers for 2D convolution.
1621       ConvolutionDimensionNumbers dnums;
1622       dnums.set_input_batch_dimension(0);
1623       dnums.set_output_batch_dimension(0);
1624       dnums.add_input_spatial_dimensions(1);
1625       dnums.add_output_spatial_dimensions(1);
1626       dnums.add_input_spatial_dimensions(2);
1627       dnums.add_output_spatial_dimensions(2);
1628       dnums.set_input_feature_dimension(3);
1629       dnums.set_output_feature_dimension(3);
1630       dnums.add_kernel_spatial_dimensions(0);
1631       dnums.add_kernel_spatial_dimensions(1);
1632       dnums.set_kernel_input_feature_dimension(2);
1633       dnums.set_kernel_output_feature_dimension(3);
1634 
1635       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1636                                 /*feature_group_count=*/4);
1637     }
1638 
1639     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1640     iota_int_init_value(input_elems, 1);
1641     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1642     auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1643 
1644     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1645     iota_int_init_value(filter_elems, 1);
1646     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1647     auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1648 
1649     auto expected_r1 =
1650         LiteralUtil::CreateR1<T>({static_cast<T>(38), static_cast<T>(98),
1651                                   static_cast<T>(176), static_cast<T>(272)});
1652     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie();
1653 
1654     auto input_literal =
1655         client_->TransferToServer(input_r4).ConsumeValueOrDie();
1656     auto filter_literal =
1657         client_->TransferToServer(filter_r4).ConsumeValueOrDie();
1658 
1659     ComputeAndCompareLiteral(&builder, expected_r4,
1660                              {input_literal.get(), filter_literal.get()},
1661                              error_spec_);
1662   }
1663 };
1664 
1665 TYPED_TEST_CASE(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid,Types)1666 TYPED_TEST(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, Types) {
1667   this->RunTest();
1668 }
1669 
1670 // Test fixture to run convolution tests with and without convolution
1671 // canonicalization enabled.
1672 class ConvolveWithAndWithoutCanonicalization
1673     : public ConvolutionTest,
1674       public ::testing::WithParamInterface<bool> {};
1675 
XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,DISABLED_ON_GPU (Convolve2D_NoSpatialDims))1676 XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
1677            DISABLED_ON_GPU(Convolve2D_NoSpatialDims)) {
1678   if (GetParam()) {
1679     execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
1680         "convolution-canonicalization");
1681   }
1682   XlaBuilder builder(TestName());
1683   Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29});
1684   Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10});
1685 
1686   auto input = Parameter(&builder, 0, input_shape, "input");
1687   auto filter = Parameter(&builder, 1, filter_shape, "filter");
1688 
1689   ConvolutionDimensionNumbers dnums;
1690   dnums.set_input_feature_dimension(0);
1691   dnums.set_input_batch_dimension(1);
1692   dnums.set_kernel_input_feature_dimension(0);
1693   dnums.set_kernel_output_feature_dimension(1);
1694   dnums.set_output_batch_dimension(0);
1695   dnums.set_output_feature_dimension(1);
1696   ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums);
1697 
1698   Array2D<float> param0(4, 29);
1699   param0.FillUnique();
1700 
1701   Array2D<float> param1(4, 10);
1702   param1.FillUnique();
1703 
1704   Array2D<float> expected_result(29, 10);
1705   expected_result.Fill(0);
1706 
1707   ComputeAndCompare(&builder,
1708                     {LiteralUtil::CreateFromArray(param0),
1709                      LiteralUtil::CreateFromArray(param1)},
1710                     error_spec_);
1711 }
1712 
1713 INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation,
1714                         ConvolveWithAndWithoutCanonicalization,
1715                         ::testing::Values(true, false));
1716 
1717 struct Convolve1DTestParam {
1718   int64 input_feature;
1719   int64 output_feature;
1720   int64 batch;
1721   int64 window_size;
1722   int64 num_windows;
1723 };
1724 
1725 class Convolve1D1WindowTestBase
1726     : public ConvolutionTest,
1727       public ::testing::WithParamInterface<Convolve1DTestParam> {
1728  protected:
1729   template <typename T>
TestImpl()1730   void TestImpl() {
1731     XlaBuilder builder(TestName());
1732     int64 input_feature = GetParam().input_feature;
1733     int64 output_feature = GetParam().output_feature;
1734     int64 batch = GetParam().batch;
1735     int64 num_windows = GetParam().num_windows;
1736     int64 window_size = GetParam().window_size;
1737     std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
1738                                      input_feature};
1739     std::vector<int64> filter_dims = {window_size, input_feature,
1740                                       output_feature};
1741     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1742     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1743     {
1744       auto input = Parameter(&builder, 0, input_shape, "input");
1745       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1746 
1747       // Tensorflow dimension numbers for 1D convolution.
1748       ConvolutionDimensionNumbers dnums;
1749       dnums.set_input_batch_dimension(0);
1750       dnums.set_output_batch_dimension(0);
1751       dnums.add_input_spatial_dimensions(1);
1752       dnums.add_output_spatial_dimensions(1);
1753       dnums.set_input_feature_dimension(2);
1754       dnums.set_output_feature_dimension(2);
1755       dnums.add_kernel_spatial_dimensions(0);
1756       dnums.set_kernel_input_feature_dimension(1);
1757       dnums.set_kernel_output_feature_dimension(2);
1758 
1759       ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums);
1760     }
1761 
1762     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1763                                static_cast<T>(1.0f));
1764     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1765     auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1766 
1767     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1768                                 static_cast<T>(1.0f));
1769 
1770     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1771     auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1772 
1773     std::vector<T> expect_elems(batch * output_feature * num_windows,
1774                                 static_cast<T>(window_size * input_feature));
1775     auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
1776     auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
1777                            .ConsumeValueOrDie();
1778 
1779     auto input_literal =
1780         client_->TransferToServer(input_r3).ConsumeValueOrDie();
1781     auto filter_literal =
1782         client_->TransferToServer(filter_r3).ConsumeValueOrDie();
1783     ComputeAndCompareLiteral(&builder, expected_r3,
1784                              {input_literal.get(), filter_literal.get()},
1785                              error_spec_);
1786   }
1787 };
1788 
1789 class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {};
1790 
XLA_TEST_P(Convolve1D1WindowTestFloat,Convolve1D1Window)1791 XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl<float>(); }
1792 
1793 INSTANTIATE_TEST_CASE_P(
1794     Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat,
1795     ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
1796                       Convolve1DTestParam{160, 1, 1, 5, 1},
1797                       Convolve1DTestParam{24, 1, 1, 20, 1},
1798                       Convolve1DTestParam{30, 1, 1, 20, 1},
1799                       Convolve1DTestParam{23, 1, 1, 20, 20},
1800                       Convolve1DTestParam{25, 1, 1, 20, 1},
1801                       Convolve1DTestParam{24, 1, 1, 10, 5},
1802                       Convolve1DTestParam{160, 1, 1, 10, 1},
1803                       Convolve1DTestParam{255, 1, 1, 3, 1},
1804                       Convolve1DTestParam{130, 1, 1, 1, 3},
1805                       Convolve1DTestParam{64, 1, 1, 1, 1},
1806                       Convolve1DTestParam{128, 1, 1, 1, 1},
1807                       Convolve1DTestParam{139, 1, 1, 128, 1},
1808                       Convolve1DTestParam{1, 10, 10, 1, 10},
1809                       Convolve1DTestParam{1, 10, 130, 1, 2},
1810                       Convolve1DTestParam{1, 10, 130, 1, 1},
1811                       Convolve1DTestParam{1, 64, 64, 1, 10},
1812                       Convolve1DTestParam{1, 65, 65, 1, 1},
1813                       Convolve1DTestParam{1, 128, 128, 1, 1},
1814                       Convolve1DTestParam{128, 128, 128, 128, 1},
1815                       Convolve1DTestParam{1, 128, 128, 1, 1},
1816                       Convolve1DTestParam{2, 2, 2, 2, 1},
1817                       Convolve1DTestParam{161, 1, 1, 10, 1},
1818                       Convolve1DTestParam{900, 1, 1, 10, 1},
1819                       Convolve1DTestParam{640, 3, 3, 128, 1})
1820 
1821 );
1822 
#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU)
// Eigen::half instantiation of the single-window 1D convolution test.
class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {};

XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) {
  TestImpl<Eigen::half>();
}

// Parameter list for the half-precision instantiation. Each entry is
// {input_feature, output_feature, batch, window_size, num_windows}.
// The list is defined outside INSTANTIATE_TEST_CASE_P on purpose: a
// preprocessing directive (the #if below) inside a macro's argument list is
// undefined behavior in C++, and some compilers reject it.
const Convolve1DTestParam kConvolve1D1WindowHalfParams[] = {
    Convolve1DTestParam{1, 1, 1, 1, 2},
    Convolve1DTestParam{160, 1, 1, 5, 1},
    Convolve1DTestParam{24, 1, 1, 20, 1},
    Convolve1DTestParam{30, 1, 1, 20, 1},
    Convolve1DTestParam{23, 1, 1, 20, 20},
    Convolve1DTestParam{25, 1, 1, 20, 1},
    Convolve1DTestParam{24, 1, 1, 10, 5},
    Convolve1DTestParam{160, 1, 1, 10, 1},
    Convolve1DTestParam{255, 1, 1, 3, 1},
    Convolve1DTestParam{130, 1, 1, 1, 3},
    Convolve1DTestParam{64, 1, 1, 1, 1},
    Convolve1DTestParam{128, 1, 1, 1, 1},
// TODO(b/72566306): The following five tests failed on CPU with unreasonable
// relative errors.  Last ran on 2018-02-22.
#if XLA_TEST_BACKEND_GPU
    Convolve1DTestParam{139, 1, 1, 128, 1},
    Convolve1DTestParam{640, 3, 3, 128, 1},
    Convolve1DTestParam{900, 1, 1, 10, 1},
    Convolve1DTestParam{1, 10, 10, 1, 10},
    Convolve1DTestParam{1, 10, 130, 1, 1},
#endif
    Convolve1DTestParam{1, 10, 130, 1, 2},
    Convolve1DTestParam{1, 64, 64, 1, 10},
    Convolve1DTestParam{1, 65, 65, 1, 1},
    Convolve1DTestParam{1, 128, 128, 1, 1},
    Convolve1DTestParam{128, 128, 128, 128, 1},
    Convolve1DTestParam{2, 2, 2, 2, 1},
    Convolve1DTestParam{161, 1, 1, 10, 1},
};

INSTANTIATE_TEST_CASE_P(Convolve1D1WindowTest_Instantiation,
                        Convolve1D1WindowTestHalf,
                        ::testing::ValuesIn(kConvolve1D1WindowHalfParams));
#endif
1864 
// Minimal bf16 convolution. A 1x1x1x2 input (YX plane [1, 2]) is convolved
// with a 1x1x1x2 filter (YX plane [5, 6]) under VALID padding and unit
// strides. Conv() is called without explicit dimension numbers.
XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
  XlaBuilder builder(TestName());
  // Input and filter share the same shape.
  const Shape operand_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
  auto lhs = Parameter(&builder, 0, operand_shape, "input");
  auto rhs = Parameter(&builder, 1, operand_shape, "filter");
  Conv(lhs, rhs, /*window_strides=*/{1, 1}, Padding::kValid);

  Array4D<bfloat16> lhs_values(1, 1, 1, 2);
  lhs_values.FillWithYX(Array2D<bfloat16>({{bfloat16(1), bfloat16(2)}}));
  Array4D<bfloat16> rhs_values(1, 1, 1, 2);
  rhs_values.FillWithYX(Array2D<bfloat16>({{bfloat16(5), bfloat16(6)}}));

  ComputeAndCompare(&builder,
                    {LiteralUtil::CreateFromArray(lhs_values),
                     LiteralUtil::CreateFromArray(rhs_values)},
                    error_spec_);
}
1887 
1888 // Check that GPU convs still work if the CudnnAlgorithmPicker pass is disabled.
1889 // (We run this test on all platforms, because, what the heck.)
XLA_TEST_F(ConvolutionTest,NoCudnnAlgorithmPicker)1890 XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
1891   execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
1892       "cudnn-conv-algorithm-picker");
1893 
1894   XlaBuilder builder(TestName());
1895   Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1896   Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1897   auto input = Parameter(&builder, 0, input_shape, "input");
1898   auto filter = Parameter(&builder, 1, filter_shape, "filter");
1899   Conv(input, filter, {1, 1}, Padding::kValid);
1900 
1901   Array4D<float> input_data(1, 1, 1, 2);
1902   input_data.FillIota(0);
1903   Array4D<float> filter_data(1, 1, 1, 2);
1904   filter_data.FillIota(10);
1905 
1906   ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
1907                                LiteralUtil::CreateFromArray(filter_data)});
1908 }
1909 
// Grouped convolution with feature_group_count = 64 over a 1x64x100x100 (bf01)
// input, so each group sees a single input feature. The 7x7 filter (01io,
// 1 input feature, 64 output features) is embedded as a constant. Padding of
// 3 on each side of both spatial dimensions is used with unit strides.
XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
  XlaBuilder builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100});
  Array4D<float> input_data(1, 64, 100, 100);
  input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45321);
  Array4D<float> filter_data(7, 7, 1, 64);
  // Randomize the filter (with its own seed). Re-filling input_data here, as
  // before, discarded the seed-45321 input and left the filter at whatever
  // the Array4D constructor initialized it to.
  filter_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45320);
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto filter = ConstantR4FromArray4D(&builder, filter_data);

  // Specify bf01_01io->bf01 as dimension numbers.
  ConvolutionDimensionNumbers dnums;
  // Input
  dnums.set_input_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  // Kernel
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);
  // Output
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);
  ConvGeneral(input, filter, /*window_strides=*/{1, 1},
              /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
              /*feature_group_count=*/64);

  ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
                    error_spec_);
}
1945 
1946 class ConvolutionHloTest : public HloTestBase {};
1947 
XLA_TEST_F(ConvolutionHloTest,ConvolveF64Forward)1948 XLA_TEST_F(ConvolutionHloTest, ConvolveF64Forward) {
1949   constexpr char kHlo[] = R"(
1950 HloModule TestModule
1951 
1952 ENTRY Test {
1953   %arg0 = f64[3,56,56,16] parameter(0)
1954   %arg1 = f64[3,3,3,64] parameter(1)
1955   ROOT %conv = f64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
1956 })";
1957   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
1958 }
1959 
XLA_TEST_F(ConvolutionHloTest,ConvolveF32ForwardReversed)1960 XLA_TEST_F(ConvolutionHloTest, ConvolveF32ForwardReversed) {
1961   constexpr char kHlo[] = R"(
1962 HloModule TestModule
1963 
1964 ENTRY Test {
1965   %arg0 = f32[3,56,56,16] parameter(0)
1966   %arg1 = f32[3,3,3,32] parameter(1)
1967   ROOT %conv = f32[54,54,16,32] convolution(%arg0, %arg1), window={size=3x3 rhs_reversal=1x1}, dim_labels=f01b_i01o->01bf
1968 })";
1969   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
1970 }
1971 
XLA_TEST_F(ConvolutionHloTest,ConvolveF64BackwardFilter)1972 XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardFilter) {
1973   constexpr char kHlo[] = R"(
1974 HloModule TestModule
1975 
1976 ENTRY Test {
1977   %arg0 = f64[2,5,8,1] parameter(0)
1978   %arg1 = f64[2,5,8,2] parameter(1)
1979   ROOT %conv = f64[4,4,1,2] convolution(%arg0, %arg1), window={size=5x8 pad=1_2x1_2}, dim_labels=f01b_i01o->01bf
1980 })";
1981   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
1982 }
1983 
XLA_TEST_F(ConvolutionHloTest,ConvolveF64BackwardInput)1984 XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardInput) {
1985   constexpr char kHlo[] = R"(
1986 HloModule TestModule
1987 
1988 ENTRY Test {
1989   %output = f64[4,5,16,16] parameter(0)
1990   %kernel = f64[5,3,7,7] parameter(1)
1991   %reverse = f64[5,3,7,7] reverse(f64[5,3,7,7] %kernel), dimensions={2,3}
1992   ROOT %convolution = f64[4,3,16,16] convolution(%output, %reverse), window={size=7x7 pad=3_3x3_3}, dim_labels=bf01_io01->bf01
1993 })";
1994   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
1995 }
1996 
1997 }  // namespace
1998 }  // namespace xla
1999