• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests of 2+D convolution with trivial kernels and no special variations (like
17 // strides and padding).
18 
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
36 
37 namespace xla {
38 namespace {
39 
40 class ConvolutionTest : public ClientLibraryTestBase {
41  protected:
42 #if XLA_TEST_BACKEND_GPU
43   // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial
44   // convolution. So relax the absolute error threshold.
45   ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-3);
46 #else
47   ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-3);
48 #endif
49 };
50 
51 using TestTypes = ::testing::Types<
52 // TODO(b/183565702): Support integer convs on GPU.
53 #if !XLA_TEST_BACKEND_GPU
54     int32_t,
55 #endif
56 #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
57     Eigen::half,
58 #endif
59     float>;
60 
61 template <typename T>
62 class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
63  public:
RunTest()64   void RunTest() {
65     const int kInputActivationSizeY = 3;
66     const int kInputActivationSizeX = 3;
67     const int kInputActivationSizeZ = 256;
68     const int kKernelSizeX = 2;
69     const int kKernelSizeY = 2;
70     const int kOutputActivationSizeZ = 256;
71     const int kMiniBatchSize = 4;
72     auto alhs = std::make_unique<Array4D<T>>(
73         kMiniBatchSize, kInputActivationSizeZ, kInputActivationSizeY,
74         kInputActivationSizeX);
75     alhs->FillWithMultiples(static_cast<T>(static_cast<T>(1.0f)));
76     ASSERT_EQ(3, alhs->width());
77     ASSERT_EQ(3, alhs->height());
78 
79     auto arhs = std::make_unique<Array4D<T>>(kOutputActivationSizeZ,
80                                              kInputActivationSizeZ,
81                                              kKernelSizeY, kKernelSizeX);
82     Array2D<T> rhs_raster({
83         {static_cast<T>(1.0f), static_cast<T>(0.0f)},  // row 0
84         {static_cast<T>(0.0f), static_cast<T>(0.0f)},  // row 1
85     });
86     arhs->FillWithYX(rhs_raster);
87     ASSERT_EQ(2, arhs->width());
88     ASSERT_EQ(2, arhs->height());
89 
90     XlaBuilder builder(TestName());
91     auto lhs = ConstantR4FromArray4D<T>(&builder, *alhs);
92     auto rhs = ConstantR4FromArray4D<T>(&builder, *arhs);
93     PrecisionConfig precision;
94     // The left hand side of the convolution is numbers between 0 and 2304 which
95     // requires at least 11 mantissa bits and the DEFAULT precision config is
96     // allowed to round to bfloat16 which only has 7 mantissa bits.
97     precision.add_operand_precision(PrecisionConfig::HIGHEST);
98     precision.add_operand_precision(PrecisionConfig::DEFAULT);
99     Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1,
100          /*batch_group_count=*/1, &precision);
101 
102     ComputeAndCompare(&builder, {}, error_spec_);
103   }
104 };
105 
106 TYPED_TEST_CASE(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, TestTypes);
XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota,Types)107 XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, Types) {
108   this->RunTest();
109 }
110 
111 template <typename T>
112 class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
113  public:
RunTest()114   void RunTest() {
115     XlaBuilder builder(TestName());
116     Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
117     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
118     auto input = Parameter(&builder, 0, input_shape, "input");
119     auto filter = Parameter(&builder, 1, filter_shape, "filter");
120     Conv(input, filter, {1, 1}, Padding::kValid);
121 
122     Array4D<T> input_data(1, 1, 1, 2);
123     input_data.FillWithYX(Array2D<T>({
124         {static_cast<T>(1.0f), static_cast<T>(2.0f)},
125     }));
126     Array4D<T> filter_data(1, 1, 1, 2);
127     filter_data.FillWithYX(Array2D<T>({
128         {static_cast<T>(5.0f), static_cast<T>(6.0f)},
129     }));
130 
131     ComputeAndCompare(&builder,
132                       {LiteralUtil::CreateFromArray(input_data),
133                        LiteralUtil::CreateFromArray(filter_data)},
134                       error_spec_);
135   }
136 };
137 
138 TYPED_TEST_CASE(Convolve_1x1x1x2_1x1x1x2_Valid, TestTypes);
TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid,Types)139 TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid, Types) { this->RunTest(); }
140 
141 // Tests valid padding for 2D convolution in raster space.
142 template <typename T>
143 class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
144  public:
RunTest()145   void RunTest() {
146     XlaBuilder builder(TestName());
147     Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
148     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
149     auto input = Parameter(&builder, 0, input_shape, "input");
150     auto filter = Parameter(&builder, 1, filter_shape, "filter");
151     Conv(input, filter, {1, 1}, Padding::kValid);
152 
153     Array4D<T> input_data(1, 1, 4, 4);
154     input_data.FillWithYX(Array2D<T>({
155         {static_cast<T>(1.0f), static_cast<T>(2.0f), static_cast<T>(3.0f),
156          static_cast<T>(4.0f)},
157         {static_cast<T>(5.0f), static_cast<T>(6.0f), static_cast<T>(7.0f),
158          static_cast<T>(8.0f)},
159         {static_cast<T>(9.0f), static_cast<T>(10.0f), static_cast<T>(11.0f),
160          static_cast<T>(12.0f)},
161         {static_cast<T>(13.0f), static_cast<T>(14.0f), static_cast<T>(15.0f),
162          static_cast<T>(16.0f)},
163     }));
164     Array4D<T> filter_data(1, 1, 2, 2);
165     filter_data.FillWithYX(Array2D<T>({
166         {static_cast<T>(5.0f), static_cast<T>(6.0f)},
167         {static_cast<T>(7.0f), static_cast<T>(8.0f)},
168     }));
169     ComputeAndCompare(&builder,
170                       {LiteralUtil::CreateFromArray(input_data),
171                        LiteralUtil::CreateFromArray(filter_data)},
172                       error_spec_);
173   }
174 };
175 
176 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Valid, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid,Types)177 TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid, Types) { this->RunTest(); }
178 
179 // Tests same padding for 2D convolution in raster space.
180 template <typename T>
181 class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
182  public:
RunTest()183   void RunTest() {
184     XlaBuilder builder(TestName());
185     Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
186     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
187     auto input = Parameter(&builder, 0, input_shape, "input");
188     auto filter = Parameter(&builder, 1, filter_shape, "filter");
189     Conv(input, filter, {1, 1}, Padding::kSame);
190 
191     Array4D<T> input_data(1, 1, 4, 4);
192     input_data.FillWithYX(Array2D<T>({
193         {static_cast<T>(1.0f), static_cast<T>(2.0f), static_cast<T>(3.0f),
194          static_cast<T>(4.0f)},
195         {static_cast<T>(5.0f), static_cast<T>(6.0f), static_cast<T>(7.0f),
196          static_cast<T>(8.0f)},
197         {static_cast<T>(9.0f), static_cast<T>(10.0f), static_cast<T>(11.0f),
198          static_cast<T>(12.0f)},
199         {static_cast<T>(13.0f), static_cast<T>(14.0f), static_cast<T>(15.0f),
200          static_cast<T>(16.0f)},
201     }));
202     Array4D<T> filter_data(1, 1, 2, 2);
203     filter_data.FillWithYX(Array2D<T>({
204         {static_cast<T>(5.0f), static_cast<T>(6.0f)},
205         {static_cast<T>(7.0f), static_cast<T>(8.0f)},
206     }));
207 
208     ComputeAndCompare(&builder,
209                       {LiteralUtil::CreateFromArray(input_data),
210                        LiteralUtil::CreateFromArray(filter_data)},
211                       error_spec_);
212   }
213 };
214 
215 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Same, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same,Types)216 TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same, Types) { this->RunTest(); }
217 
218 // Tests same padding for 2D convolution in raster space with an odd sized
219 // kernel.
220 template <typename T>
221 class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
222  public:
RunTest()223   void RunTest() {
224     XlaBuilder builder(TestName());
225     Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
226     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 3, 3});
227     auto input = Parameter(&builder, 0, input_shape, "input");
228     auto filter = Parameter(&builder, 1, filter_shape, "filter");
229     Conv(input, filter, {1, 1}, Padding::kSame);
230 
231     Array4D<T> input_data(1, 1, 4, 4);
232     input_data.FillWithYX(
233         Array2D<T>({{static_cast<T>(1.0f), static_cast<T>(2.0f),
234                      static_cast<T>(3.0f), static_cast<T>(4.0f)},
235                     {static_cast<T>(5.0f), static_cast<T>(6.0f),
236                      static_cast<T>(7.0f), static_cast<T>(8.0f)},
237                     {static_cast<T>(9.0f), static_cast<T>(10.0f),
238                      static_cast<T>(11.0f), static_cast<T>(12.0f)},
239                     {static_cast<T>(13.0f), static_cast<T>(14.0f),
240                      static_cast<T>(15.0f), static_cast<T>(16.0f)}}));
241     Array4D<T> filter_data(1, 1, 3, 3);
242     filter_data.FillWithYX(Array2D<T>(
243         {{static_cast<T>(5.0f), static_cast<T>(6.0f), static_cast<T>(7.0f)},
244          {static_cast<T>(8.0f), static_cast<T>(9.0f), static_cast<T>(10.0f)},
245          {static_cast<T>(11.0f), static_cast<T>(12.0f),
246           static_cast<T>(13.0f)}}));
247     // clang-format on
248     ComputeAndCompare(&builder,
249                       {LiteralUtil::CreateFromArray(input_data),
250                        LiteralUtil::CreateFromArray(filter_data)},
251                       error_spec_);
252   }
253 };
254 
255 TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same,Types)256 TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); }
257 
// 3D convolution with TensorFlow-style dimension numbers (NDHWC input, DHWIO
// kernel, NDHWC output). A 1x4x2x3x3 input is convolved with a 2x2x2x3x3
// kernel at unit strides with VALID padding, giving a 1x3x1x2x3 result. Both
// operands are filled with 1, 2, 3, ... (as a flat vector, then reshaped), and
// the result is compared against precomputed values.
XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
  XlaBuilder builder(TestName());
  std::vector<int64_t> input_dims = {1, 4, 2, 3, 3};
  std::vector<int64_t> filter_dims = {2, 2, 2, 3, 3};
  Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
  Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
  {
    auto input = Parameter(&builder, 0, input_shape, "input");
    auto filter = Parameter(&builder, 1, filter_shape, "filter");

    // Tensorflow dimension numbers for 3D convolution.
    ConvolutionDimensionNumbers dnums;
    dnums.set_input_batch_dimension(0);
    dnums.set_output_batch_dimension(0);
    dnums.add_input_spatial_dimensions(1);
    dnums.add_output_spatial_dimensions(1);
    dnums.add_input_spatial_dimensions(2);
    dnums.add_output_spatial_dimensions(2);
    dnums.add_input_spatial_dimensions(3);
    dnums.add_output_spatial_dimensions(3);
    dnums.set_input_feature_dimension(4);
    dnums.set_output_feature_dimension(4);
    dnums.add_kernel_spatial_dimensions(0);
    dnums.add_kernel_spatial_dimensions(1);
    dnums.add_kernel_spatial_dimensions(2);
    dnums.set_kernel_input_feature_dimension(3);
    dnums.set_kernel_output_feature_dimension(4);

    ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid, dnums);
  }

  // std::iota is qualified explicitly. An unqualified call only resolves when
  // ADL happens to find namespace std through the iterator type, which is
  // implementation-specific.
  std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
  std::iota(input_elems.begin(), input_elems.end(), 1.0f);
  auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
  auto input_r5 = input_r1.Reshape(input_dims).value();

  std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
  std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
  auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
  auto filter_r5 = filter_r1.Reshape(filter_dims).value();

  auto expected_r1 = LiteralUtil::CreateR1<float>(
      {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
       38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
  auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).value();

  auto input_literal = client_->TransferToServer(input_r5).value();
  auto filter_literal = client_->TransferToServer(filter_r5).value();

  ComputeAndCompareLiteral(&builder, expected_r5,
                           {input_literal.get(), filter_literal.get()},
                           error_spec_);
}
311 
// Fills `values` with init_value, init_value + 1, init_value + 2, ...,
// converting each count to T.
//
// std::iota is deliberately not used here: with T = Eigen::half it failed to
// build on some build servers because the type is missing operator++.
// Counting in `int` and converting each element avoids needing that operator.
// A plain range-for is used so the helper does not depend on an absl header
// that this file does not include.
template <typename T>
void iota_int_init_value(std::vector<T>& values, int init_value) {
  for (T& value : values) {
    value = static_cast<T>(init_value++);
  }
}
319 
320 template <typename T>
321 class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
322  public:
RunTest()323   void RunTest() {
324     XlaBuilder builder(TestName());
325     std::vector<int64_t> input_dims = {1, 3, 3, 5};
326     std::vector<int64_t> filter_dims = {3, 3, 5, 3};
327     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
328     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
329     {
330       auto input = Parameter(&builder, 0, input_shape, "input");
331       auto filter = Parameter(&builder, 1, filter_shape, "filter");
332 
333       // Tensorflow dimension numbers for 2D convolution.
334       ConvolutionDimensionNumbers dnums;
335       dnums.set_input_batch_dimension(0);
336       dnums.set_output_batch_dimension(0);
337       dnums.add_input_spatial_dimensions(1);
338       dnums.add_output_spatial_dimensions(1);
339       dnums.add_input_spatial_dimensions(2);
340       dnums.add_output_spatial_dimensions(2);
341       dnums.set_input_feature_dimension(3);
342       dnums.set_output_feature_dimension(3);
343       dnums.add_kernel_spatial_dimensions(0);
344       dnums.add_kernel_spatial_dimensions(1);
345       dnums.set_kernel_input_feature_dimension(2);
346       dnums.set_kernel_output_feature_dimension(3);
347 
348       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums);
349     }
350 
351     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
352     iota_int_init_value(input_elems, 1);
353     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
354     auto input_r4 = input_r1.Reshape(input_dims).value();
355 
356     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
357     iota_int_init_value(filter_elems, 1);
358     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
359     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
360 
361     auto expected_r1 = LiteralUtil::CreateR1<T>(
362         {static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
363     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).value();
364 
365     auto input_literal = client_->TransferToServer(input_r4).value();
366     auto filter_literal = client_->TransferToServer(filter_r4).value();
367 
368     ComputeAndCompareLiteral(&builder, expected_r4,
369                              {input_literal.get(), filter_literal.get()},
370                              error_spec_);
371   }
372 };
373 
374 TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x3_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid,Types)375 TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid, Types) { this->RunTest(); }
376 
377 template <typename T>
378 class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
379  public:
RunTest()380   void RunTest() {
381     XlaBuilder builder(TestName());
382     std::vector<int64_t> input_dims = {1, 3, 3, 5};
383     std::vector<int64_t> filter_dims = {3, 3, 1, 15};
384     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
385     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
386     {
387       auto input = Parameter(&builder, 0, input_shape, "input");
388       auto filter = Parameter(&builder, 1, filter_shape, "filter");
389 
390       // Tensorflow dimension numbers for 2D convolution.
391       ConvolutionDimensionNumbers dnums;
392       dnums.set_input_batch_dimension(0);
393       dnums.set_output_batch_dimension(0);
394       dnums.add_input_spatial_dimensions(1);
395       dnums.add_output_spatial_dimensions(1);
396       dnums.add_input_spatial_dimensions(2);
397       dnums.add_output_spatial_dimensions(2);
398       dnums.set_input_feature_dimension(3);
399       dnums.set_output_feature_dimension(3);
400       dnums.add_kernel_spatial_dimensions(0);
401       dnums.add_kernel_spatial_dimensions(1);
402       dnums.set_kernel_input_feature_dimension(2);
403       dnums.set_kernel_output_feature_dimension(3);
404 
405       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
406                                 /*feature_group_count=*/5);
407     }
408 
409     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
410     iota_int_init_value(input_elems, 1);
411     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
412     auto input_r4 = input_r1.Reshape(input_dims).value();
413 
414     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
415     iota_int_init_value(filter_elems, 1);
416     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
417     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
418 
419     auto expected_r1 = LiteralUtil::CreateR1<T>(
420         {static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
421          static_cast<T>(17172), static_cast<T>(17370), static_cast<T>(17568),
422          static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
423          static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
424          static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
425     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).value();
426 
427     auto input_literal = client_->TransferToServer(input_r4).value();
428     auto filter_literal = client_->TransferToServer(filter_r4).value();
429 
430     ComputeAndCompareLiteral(&builder, expected_r4,
431                              {input_literal.get(), filter_literal.get()},
432                              error_spec_);
433   }
434 };
435 
436 TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid,Types)437 TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) {
438   this->RunTest();
439 }
440 
441 template <typename T>
442 class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid : public ConvolutionTest {
443  public:
RunTest()444   void RunTest() {
445     XlaBuilder builder(TestName());
446     std::vector<int64_t> input_dims = {1, 4, 4, 5};
447     std::vector<int64_t> filter_dims = {3, 3, 1, 5};
448     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
449     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
450     {
451       auto input = Parameter(&builder, 0, input_shape, "input");
452       auto filter = Parameter(&builder, 1, filter_shape, "filter");
453 
454       // Tensorflow dimension numbers for 2D convolution.
455       ConvolutionDimensionNumbers dnums;
456       dnums.set_input_batch_dimension(0);
457       dnums.set_output_batch_dimension(0);
458       dnums.add_input_spatial_dimensions(1);
459       dnums.add_output_spatial_dimensions(1);
460       dnums.add_input_spatial_dimensions(2);
461       dnums.add_output_spatial_dimensions(2);
462       dnums.set_input_feature_dimension(3);
463       dnums.set_output_feature_dimension(3);
464       dnums.add_kernel_spatial_dimensions(0);
465       dnums.add_kernel_spatial_dimensions(1);
466       dnums.set_kernel_input_feature_dimension(2);
467       dnums.set_kernel_output_feature_dimension(3);
468 
469       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
470                                 /*feature_group_count=*/5);
471     }
472 
473     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
474     iota_int_init_value(input_elems, 1);
475     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
476     auto input_r4 = input_r1.Reshape(input_dims).value();
477 
478     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
479     iota_int_init_value(filter_elems, 1);
480     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
481     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
482 
483     auto expected_r1 = LiteralUtil::CreateR1<T>(
484         {static_cast<T>(6864),  static_cast<T>(7296),  static_cast<T>(7746),
485          static_cast<T>(8214),  static_cast<T>(8700),  static_cast<T>(7809),
486          static_cast<T>(8286),  static_cast<T>(8781),  static_cast<T>(9294),
487          static_cast<T>(9825),  static_cast<T>(10644), static_cast<T>(11256),
488          static_cast<T>(11886), static_cast<T>(12534), static_cast<T>(13200),
489          static_cast<T>(11589), static_cast<T>(12246), static_cast<T>(12921),
490          static_cast<T>(13614), static_cast<T>(14325)});
491     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).value();
492 
493     auto input_literal = client_->TransferToServer(input_r4).value();
494     auto filter_literal = client_->TransferToServer(filter_r4).value();
495 
496     ComputeAndCompareLiteral(&builder, expected_r4,
497                              {input_literal.get(), filter_literal.get()},
498                              error_spec_);
499 
500     auto filter_r = filter_r1.Reshape(filter_dims);
501   }
502 };
503 
504 TYPED_TEST_CASE(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid,Types)505 TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, Types) {
506   this->RunTest();
507 }
508 
509 template <typename T>
510 class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid : public ConvolutionTest {
511  public:
RunTest()512   void RunTest() {
513     XlaBuilder builder(TestName());
514     std::vector<int64_t> input_dims = {1, 4, 4, 512};
515     std::vector<int64_t> filter_dims = {3, 3, 1, 512};
516     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
517     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
518     {
519       auto input = Parameter(&builder, 0, input_shape, "input");
520       auto filter = Parameter(&builder, 1, filter_shape, "filter");
521 
522       // Tensorflow dimension numbers for 2D convolution.
523       ConvolutionDimensionNumbers dnums;
524       dnums.set_input_batch_dimension(0);
525       dnums.set_output_batch_dimension(0);
526       dnums.add_input_spatial_dimensions(1);
527       dnums.add_output_spatial_dimensions(1);
528       dnums.add_input_spatial_dimensions(2);
529       dnums.add_output_spatial_dimensions(2);
530       dnums.set_input_feature_dimension(3);
531       dnums.set_output_feature_dimension(3);
532       dnums.add_kernel_spatial_dimensions(0);
533       dnums.add_kernel_spatial_dimensions(1);
534       dnums.set_kernel_input_feature_dimension(2);
535       dnums.set_kernel_output_feature_dimension(3);
536 
537       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
538                                 /*feature_group_count=*/512);
539     }
540 
541     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
542                                static_cast<T>(1));
543     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
544     auto input_r4 = input_r1.Reshape(input_dims).value();
545 
546     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
547                                 static_cast<T>(2));
548     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
549     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
550 
551     std::vector<T> output_elems(2048, static_cast<T>(18));
552 
553     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
554     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).value();
555 
556     auto input_literal = client_->TransferToServer(input_r4).value();
557     auto filter_literal = client_->TransferToServer(filter_r4).value();
558 
559     ComputeAndCompareLiteral(&builder, expected_r4,
560                              {input_literal.get(), filter_literal.get()},
561                              error_spec_);
562   }
563 };
564 
565 TYPED_TEST_CASE(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid,Types)566 TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, Types) {
567   this->RunTest();
568 }
569 
570 template <typename T>
571 class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes
572     : public ConvolutionTest {
573  public:
RunTest()574   void RunTest() {
575     XlaBuilder builder(TestName());
576     std::vector<int64_t> input_dims = {1, 4, 4, 512};
577     std::vector<int64_t> filter_dims = {3, 3, 1, 512};
578     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
579     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
580     {
581       auto input = Parameter(&builder, 0, input_shape, "input");
582       auto filter = Parameter(&builder, 1, filter_shape, "filter");
583 
584       // Tensorflow dimension numbers for 2D convolution.
585       ConvolutionDimensionNumbers dnums;
586       dnums.set_input_batch_dimension(0);
587       dnums.set_output_batch_dimension(0);
588       dnums.add_input_spatial_dimensions(1);
589       dnums.add_output_spatial_dimensions(1);
590       dnums.add_input_spatial_dimensions(2);
591       dnums.add_output_spatial_dimensions(2);
592       dnums.set_input_feature_dimension(3);
593       dnums.set_output_feature_dimension(3);
594       dnums.add_kernel_spatial_dimensions(0);
595       dnums.add_kernel_spatial_dimensions(1);
596       dnums.set_kernel_input_feature_dimension(2);
597       dnums.set_kernel_output_feature_dimension(3);
598 
599       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
600                                 /*feature_group_count=*/512);
601     }
602 
603     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
604                                static_cast<T>(1));
605     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
606     auto input_r4 = input_r1.Reshape(input_dims).value();
607 
608     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
609                                 static_cast<T>(2));
610     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
611     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
612 
613     std::vector<T> output_elems(2048, static_cast<T>(18));
614 
615     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
616     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).value();
617     auto expected_r4_relaid =
618         expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
619 
620     auto input_literal = client_->TransferToServer(input_r4).value();
621     auto filter_literal = client_->TransferToServer(filter_r4).value();
622 
623     ComputeAndCompareLiteral(&builder, expected_r4_relaid,
624                              {input_literal.get(), filter_literal.get()},
625                              error_spec_, &expected_r4_relaid.shape());
626   }
627 };
628 
629 TYPED_TEST_CASE(
630     Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes,
631     TestTypes);
TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes,Types)632 TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes,
633            Types) {
634   this->RunTest();
635 }
636 
637 template <typename T>
638 class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes
639     : public ConvolutionTest {
640  public:
RunTest()641   void RunTest() {
642     XlaBuilder builder(TestName());
643     std::vector<int64_t> input_dims = {256, 4, 4, 512};
644     std::vector<int64_t> filter_dims = {3, 3, 1, 512};
645     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
646     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
647     {
648       auto input = Parameter(&builder, 0, input_shape, "input");
649       auto filter = Parameter(&builder, 1, filter_shape, "filter");
650 
651       // Tensorflow dimension numbers for 2D convolution.
652       ConvolutionDimensionNumbers dnums;
653       dnums.set_input_batch_dimension(0);
654       dnums.set_output_batch_dimension(0);
655       dnums.add_input_spatial_dimensions(1);
656       dnums.add_output_spatial_dimensions(1);
657       dnums.add_input_spatial_dimensions(2);
658       dnums.add_output_spatial_dimensions(2);
659       dnums.set_input_feature_dimension(3);
660       dnums.set_output_feature_dimension(3);
661       dnums.add_kernel_spatial_dimensions(0);
662       dnums.add_kernel_spatial_dimensions(1);
663       dnums.set_kernel_input_feature_dimension(2);
664       dnums.set_kernel_output_feature_dimension(3);
665 
666       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
667                                 /*feature_group_count=*/512);
668     }
669 
670     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
671                                static_cast<T>(1));
672     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
673     auto input_r4 = input_r1.Reshape(input_dims).value();
674     auto input_r4_relaid =
675         input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
676 
677     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
678                                 static_cast<T>(2));
679     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
680     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
681 
682     std::vector<T> output_elems(2048 * 256, static_cast<T>(18));
683 
684     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
685     auto expected_r4 = expected_r1.Reshape({256, 2, 2, 512}).value();
686 
687     auto input_literal = client_->TransferToServer(input_r4_relaid).value();
688     auto filter_literal = client_->TransferToServer(filter_r4).value();
689 
690     ComputeAndCompareLiteral(&builder, expected_r4,
691                              {input_literal.get(), filter_literal.get()},
692                              error_spec_);
693   }
694 };
695 
// Instantiate the typed fixture once per element type in TestTypes; the test
// body only delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes,
                TestTypes);
TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes,
           Types) {
  TestFixture::RunTest();
}
702 
703 template <typename T>
704 class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes
705     : public ConvolutionTest {
706  public:
RunTest()707   void RunTest() {
708     XlaBuilder builder(TestName());
709     std::vector<int64_t> input_dims = {256, 4, 4, 512};
710     std::vector<int64_t> filter_dims = {3, 3, 1, 512};
711     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
712     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
713     {
714       auto input = Parameter(&builder, 0, input_shape, "input");
715       auto filter = Parameter(&builder, 1, filter_shape, "filter");
716 
717       // Tensorflow dimension numbers for 2D convolution.
718       ConvolutionDimensionNumbers dnums;
719       dnums.set_input_batch_dimension(0);
720       dnums.set_output_batch_dimension(0);
721       dnums.add_input_spatial_dimensions(1);
722       dnums.add_output_spatial_dimensions(1);
723       dnums.add_input_spatial_dimensions(2);
724       dnums.add_output_spatial_dimensions(2);
725       dnums.set_input_feature_dimension(3);
726       dnums.set_output_feature_dimension(3);
727       dnums.add_kernel_spatial_dimensions(0);
728       dnums.add_kernel_spatial_dimensions(1);
729       dnums.set_kernel_input_feature_dimension(2);
730       dnums.set_kernel_output_feature_dimension(3);
731 
732       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
733                                 /*feature_group_count=*/512);
734     }
735 
736     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
737                                static_cast<T>(1));
738     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
739     auto input_r4 = input_r1.Reshape(input_dims).value();
740     auto input_r4_relaid =
741         input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
742 
743     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
744                                 static_cast<T>(2));
745     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
746     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
747 
748     std::vector<T> output_elems(2048 * 256, static_cast<T>(18));
749 
750     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
751     auto expected_r4 = expected_r1.Reshape({256, 2, 2, 512}).value();
752     auto expected_r4_relaid =
753         expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
754 
755     auto input_literal = client_->TransferToServer(input_r4_relaid).value();
756     auto filter_literal = client_->TransferToServer(filter_r4).value();
757 
758     ComputeAndCompareLiteral(&builder, expected_r4_relaid,
759                              {input_literal.get(), filter_literal.get()},
760                              error_spec_, &expected_r4_relaid.shape());
761   }
762 };
763 
// Instantiate the typed fixture once per element type in TestTypes; the test
// body only delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes,
                TestTypes);
TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes,
           Types) {
  TestFixture::RunTest();
}
770 
771 template <typename T>
772 class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes
773     : public ConvolutionTest {
774  public:
RunTest()775   void RunTest() {
776     XlaBuilder builder(TestName());
777     std::vector<int64_t> input_dims = {1, 4, 4, 5};
778     std::vector<int64_t> filter_dims = {3, 3, 1, 5};
779     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
780     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
781     {
782       auto input = Parameter(&builder, 0, input_shape, "input");
783       auto filter = Parameter(&builder, 1, filter_shape, "filter");
784 
785       // Tensorflow dimension numbers for 2D convolution.
786       ConvolutionDimensionNumbers dnums;
787       dnums.set_input_batch_dimension(0);
788       dnums.set_output_batch_dimension(0);
789       dnums.add_input_spatial_dimensions(1);
790       dnums.add_output_spatial_dimensions(1);
791       dnums.add_input_spatial_dimensions(2);
792       dnums.add_output_spatial_dimensions(2);
793       dnums.set_input_feature_dimension(3);
794       dnums.set_output_feature_dimension(3);
795       dnums.add_kernel_spatial_dimensions(0);
796       dnums.add_kernel_spatial_dimensions(1);
797       dnums.set_kernel_input_feature_dimension(2);
798       dnums.set_kernel_output_feature_dimension(3);
799 
800       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
801                                 /*feature_group_count=*/5);
802     }
803 
804     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
805     iota_int_init_value(input_elems, 1);
806     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
807     auto input_r4 = input_r1.Reshape(input_dims).value();
808     auto input_r4_relaid =
809         input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
810 
811     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
812     iota_int_init_value(filter_elems, 1);
813     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
814     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
815 
816     auto expected_r1 = LiteralUtil::CreateR1<T>(
817         {static_cast<T>(6864),  static_cast<T>(7296),  static_cast<T>(7746),
818          static_cast<T>(8214),  static_cast<T>(8700),  static_cast<T>(7809),
819          static_cast<T>(8286),  static_cast<T>(8781),  static_cast<T>(9294),
820          static_cast<T>(9825),  static_cast<T>(10644), static_cast<T>(11256),
821          static_cast<T>(11886), static_cast<T>(12534), static_cast<T>(13200),
822          static_cast<T>(11589), static_cast<T>(12246), static_cast<T>(12921),
823          static_cast<T>(13614), static_cast<T>(14325)});
824     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).value();
825     auto expected_r4_relaid =
826         expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
827 
828     auto input_literal = client_->TransferToServer(input_r4_relaid).value();
829     auto filter_literal = client_->TransferToServer(filter_r4).value();
830 
831     ComputeAndCompareLiteral(&builder, expected_r4_relaid,
832                              {input_literal.get(), filter_literal.get()},
833                              error_spec_, &expected_r4_relaid.shape());
834   }
835 };
836 
// Instantiate the typed fixture once per element type in TestTypes; the test
// body only delegates to the fixture's RunTest().
TYPED_TEST_CASE(
    Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes,
    TestTypes);
TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes,
           Types) {
  TestFixture::RunTest();
}
844 
845 template <typename T>
846 class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest {
847  public:
RunTest()848   void RunTest() {
849     XlaBuilder builder(TestName());
850     std::vector<int64_t> input_dims = {1, 4, 4, 160};
851     std::vector<int64_t> filter_dims = {3, 3, 1, 160};
852     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
853     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
854     {
855       auto input = Parameter(&builder, 0, input_shape, "input");
856       auto filter = Parameter(&builder, 1, filter_shape, "filter");
857 
858       // Tensorflow dimension numbers for 2D convolution.
859       ConvolutionDimensionNumbers dnums;
860       dnums.set_input_batch_dimension(0);
861       dnums.set_output_batch_dimension(0);
862       dnums.add_input_spatial_dimensions(1);
863       dnums.add_output_spatial_dimensions(1);
864       dnums.add_input_spatial_dimensions(2);
865       dnums.add_output_spatial_dimensions(2);
866       dnums.set_input_feature_dimension(3);
867       dnums.set_output_feature_dimension(3);
868       dnums.add_kernel_spatial_dimensions(0);
869       dnums.add_kernel_spatial_dimensions(1);
870       dnums.set_kernel_input_feature_dimension(2);
871       dnums.set_kernel_output_feature_dimension(3);
872 
873       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
874                                 /*feature_group_count=*/160);
875     }
876 
877     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
878                                static_cast<T>(1));
879     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
880     auto input_r4 = input_r1.Reshape(input_dims).value();
881 
882     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
883                                 static_cast<T>(2));
884     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
885     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
886 
887     std::vector<T> output_elems(640, static_cast<T>(18));
888 
889     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
890     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).value();
891 
892     auto input_literal = client_->TransferToServer(input_r4).value();
893     auto filter_literal = client_->TransferToServer(filter_r4).value();
894 
895     ComputeAndCompareLiteral(&builder, expected_r4,
896                              {input_literal.get(), filter_literal.get()},
897                              error_spec_);
898   }
899 };
900 
// Instantiate the typed fixture once per element type in TestTypes; the test
// body only delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, Types) {
  TestFixture::RunTest();
}
905 
906 template <typename T>
907 class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes
908     : public ConvolutionTest {
909  public:
RunTest()910   void RunTest() {
911     XlaBuilder builder(TestName());
912     std::vector<int64_t> input_dims = {1, 4, 4, 160};
913     std::vector<int64_t> filter_dims = {3, 3, 1, 160};
914     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
915     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
916     {
917       auto input = Parameter(&builder, 0, input_shape, "input");
918       auto filter = Parameter(&builder, 1, filter_shape, "filter");
919 
920       // Tensorflow dimension numbers for 2D convolution.
921       ConvolutionDimensionNumbers dnums;
922       dnums.set_input_batch_dimension(0);
923       dnums.set_output_batch_dimension(0);
924       dnums.add_input_spatial_dimensions(1);
925       dnums.add_output_spatial_dimensions(1);
926       dnums.add_input_spatial_dimensions(2);
927       dnums.add_output_spatial_dimensions(2);
928       dnums.set_input_feature_dimension(3);
929       dnums.set_output_feature_dimension(3);
930       dnums.add_kernel_spatial_dimensions(0);
931       dnums.add_kernel_spatial_dimensions(1);
932       dnums.set_kernel_input_feature_dimension(2);
933       dnums.set_kernel_output_feature_dimension(3);
934 
935       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
936                                 /*feature_group_count=*/160);
937     }
938 
939     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
940                                static_cast<T>(1));
941     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
942     auto input_r4 = input_r1.Reshape(input_dims).value();
943     auto input_r4_relaid =
944         input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
945 
946     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
947                                 static_cast<T>(2));
948     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
949     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
950 
951     std::vector<T> output_elems(640, static_cast<T>(18));
952 
953     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
954     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).value();
955     auto expected_r4_relaid =
956         expected_r4.Relayout(LayoutUtil::MakeLayout({3, 0, 2, 1}));
957 
958     auto input_literal = client_->TransferToServer(input_r4_relaid).value();
959     auto filter_literal = client_->TransferToServer(filter_r4).value();
960 
961     ComputeAndCompareLiteral(&builder, expected_r4_relaid,
962                              {input_literal.get(), filter_literal.get()},
963                              error_spec_, &expected_r4_relaid.shape());
964   }
965 };
966 
// Instantiate the typed fixture once per element type in TestTypes; the test
// body only delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes,
                TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes,
           Types) {
  TestFixture::RunTest();
}
973 
974 template <typename T>
975 class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Both_Batch_In_Lanes
976     : public ConvolutionTest {
977  public:
RunTest()978   void RunTest() {
979     XlaBuilder builder(TestName());
980     std::vector<int64_t> input_dims = {1, 4, 4, 160};
981     std::vector<int64_t> filter_dims = {3, 3, 1, 160};
982     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
983     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
984     {
985       auto input = Parameter(&builder, 0, input_shape, "input");
986       auto filter = Parameter(&builder, 1, filter_shape, "filter");
987 
988       // Tensorflow dimension numbers for 2D convolution.
989       ConvolutionDimensionNumbers dnums;
990       dnums.set_input_batch_dimension(0);
991       dnums.set_output_batch_dimension(0);
992       dnums.add_input_spatial_dimensions(1);
993       dnums.add_output_spatial_dimensions(1);
994       dnums.add_input_spatial_dimensions(2);
995       dnums.add_output_spatial_dimensions(2);
996       dnums.set_input_feature_dimension(3);
997       dnums.set_output_feature_dimension(3);
998       dnums.add_kernel_spatial_dimensions(0);
999       dnums.add_kernel_spatial_dimensions(1);
1000       dnums.set_kernel_input_feature_dimension(2);
1001       dnums.set_kernel_output_feature_dimension(3);
1002 
1003       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1004                                 /*feature_group_count=*/160);
1005     }
1006 
1007     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1008                                static_cast<T>(1));
1009     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1010     auto input_r4 = input_r1.Reshape(input_dims).value();
1011     auto input_r4_relaid =
1012         input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
1013 
1014     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1015                                 static_cast<T>(2));
1016     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1017     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1018 
1019     std::vector<T> output_elems(640, static_cast<T>(18));
1020 
1021     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1022     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).value();
1023     auto expected_r4_relaid =
1024         expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1}));
1025 
1026     auto input_literal = client_->TransferToServer(input_r4_relaid).value();
1027     auto filter_literal = client_->TransferToServer(filter_r4).value();
1028 
1029     ComputeAndCompareLiteral(&builder, expected_r4_relaid,
1030                              {input_literal.get(), filter_literal.get()},
1031                              error_spec_, &expected_r4_relaid.shape());
1032   }
1033 };
1034 
// Instantiate the typed fixture once per element type in TestTypes; the test
// body only delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Both_Batch_In_Lanes,
                TestTypes);
TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Both_Batch_In_Lanes,
           Types) {
  TestFixture::RunTest();
}
1041 
1042 template <typename T>
1043 class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid
1044     : public ConvolutionTest {
1045  public:
RunTest()1046   void RunTest() {
1047     XlaBuilder builder(TestName());
1048     std::vector<int64_t> input_dims = {1, 4, 4, 1024};
1049     std::vector<int64_t> filter_dims = {3, 3, 1, 1024};
1050     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1051     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1052     {
1053       auto input = Parameter(&builder, 0, input_shape, "input");
1054       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1055 
1056       // Tensorflow dimension numbers for 2D convolution.
1057       ConvolutionDimensionNumbers dnums;
1058       dnums.set_input_batch_dimension(0);
1059       dnums.set_output_batch_dimension(0);
1060       dnums.add_input_spatial_dimensions(1);
1061       dnums.add_output_spatial_dimensions(1);
1062       dnums.add_input_spatial_dimensions(2);
1063       dnums.add_output_spatial_dimensions(2);
1064       dnums.set_input_feature_dimension(3);
1065       dnums.set_output_feature_dimension(3);
1066       dnums.add_kernel_spatial_dimensions(0);
1067       dnums.add_kernel_spatial_dimensions(1);
1068       dnums.set_kernel_input_feature_dimension(2);
1069       dnums.set_kernel_output_feature_dimension(3);
1070 
1071       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1072                                 /*feature_group_count=*/1024);
1073     }
1074 
1075     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1076                                static_cast<T>(1));
1077     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1078     auto input_r4 = input_r1.Reshape(input_dims).value();
1079 
1080     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1081                                 static_cast<T>(2));
1082     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1083     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1084 
1085     std::vector<T> output_elems(4096, static_cast<T>(18));
1086 
1087     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1088     auto expected_r4 = expected_r1.Reshape({1, 2, 2, 1024}).value();
1089 
1090     auto input_literal = client_->TransferToServer(input_r4).value();
1091     auto filter_literal = client_->TransferToServer(filter_r4).value();
1092 
1093     ComputeAndCompareLiteral(&builder, expected_r4,
1094                              {input_literal.get(), filter_literal.get()},
1095                              error_spec_);
1096   }
1097 };
1098 
// Instantiate the typed fixture once per element type in TestTypes; the test
// body only delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, Types) {
  TestFixture::RunTest();
}
1103 
1104 template <typename T>
1105 class Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid : public ConvolutionTest {
1106  public:
RunTest()1107   void RunTest() {
1108     XlaBuilder builder(TestName());
1109     std::vector<int64_t> input_dims = {1, 2, 2, 6};
1110     std::vector<int64_t> filter_dims = {2, 2, 2, 12};
1111     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1112     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1113     {
1114       auto input = Parameter(&builder, 0, input_shape, "input");
1115       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1116 
1117       // Tensorflow dimension numbers for 2D convolution.
1118       ConvolutionDimensionNumbers dnums;
1119       dnums.set_input_batch_dimension(0);
1120       dnums.set_output_batch_dimension(0);
1121       dnums.add_input_spatial_dimensions(1);
1122       dnums.add_output_spatial_dimensions(1);
1123       dnums.add_input_spatial_dimensions(2);
1124       dnums.add_output_spatial_dimensions(2);
1125       dnums.set_input_feature_dimension(3);
1126       dnums.set_output_feature_dimension(3);
1127       dnums.add_kernel_spatial_dimensions(0);
1128       dnums.add_kernel_spatial_dimensions(1);
1129       dnums.set_kernel_input_feature_dimension(2);
1130       dnums.set_kernel_output_feature_dimension(3);
1131 
1132       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1133                                 /*feature_group_count=*/3);
1134     }
1135 
1136     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1137     iota_int_init_value(input_elems, 1);
1138     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1139     auto input_r4 = input_r1.Reshape(input_dims).value();
1140 
1141     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1142     iota_int_init_value(filter_elems, 1);
1143     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1144     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1145 
1146     auto expected_r1 = LiteralUtil::CreateR1<T>(
1147         {static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
1148          static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
1149          static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
1150          static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
1151     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).value();
1152 
1153     auto input_literal = client_->TransferToServer(input_r4).value();
1154     auto filter_literal = client_->TransferToServer(filter_r4).value();
1155 
1156     ComputeAndCompareLiteral(&builder, expected_r4,
1157                              {input_literal.get(), filter_literal.get()},
1158                              error_spec_);
1159   }
1160 };
1161 
// Instantiate the typed fixture once per element type in TestTypes; the test
// body only delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, Types) {
  TestFixture::RunTest();
}
1166 
1167 template <typename T>
1168 class Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid : public ConvolutionTest {
1169  public:
RunTest()1170   void RunTest() {
1171     XlaBuilder builder(TestName());
1172     std::vector<int64_t> input_dims = {1, 2, 2, 1024};
1173     std::vector<int64_t> filter_dims = {2, 2, 128, 512};
1174     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1175     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1176     {
1177       auto input = Parameter(&builder, 0, input_shape, "input");
1178       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1179 
1180       // Tensorflow dimension numbers for 2D convolution.
1181       ConvolutionDimensionNumbers dnums;
1182       dnums.set_input_batch_dimension(0);
1183       dnums.set_output_batch_dimension(0);
1184       dnums.add_input_spatial_dimensions(1);
1185       dnums.add_output_spatial_dimensions(1);
1186       dnums.add_input_spatial_dimensions(2);
1187       dnums.add_output_spatial_dimensions(2);
1188       dnums.set_input_feature_dimension(3);
1189       dnums.set_output_feature_dimension(3);
1190       dnums.add_kernel_spatial_dimensions(0);
1191       dnums.add_kernel_spatial_dimensions(1);
1192       dnums.set_kernel_input_feature_dimension(2);
1193       dnums.set_kernel_output_feature_dimension(3);
1194 
1195       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1196                                 /*feature_group_count=*/8);
1197     }
1198 
1199     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1200                                static_cast<T>(1));
1201 
1202     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1203     auto input_r4 = input_r1.Reshape(input_dims).value();
1204 
1205     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1206                                 static_cast<T>(2));
1207 
1208     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1209     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1210 
1211     std::vector<T> output_elems(512, static_cast<T>(1024));
1212     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1213     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 512}).value();
1214 
1215     auto input_literal = client_->TransferToServer(input_r4).value();
1216     auto filter_literal = client_->TransferToServer(filter_r4).value();
1217 
1218     ComputeAndCompareLiteral(&builder, expected_r4,
1219                              {input_literal.get(), filter_literal.get()},
1220                              error_spec_);
1221   }
1222 };
1223 
// Instantiate the typed fixture once per element type in TestTypes; the test
// body only delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, Types) {
  TestFixture::RunTest();
}
1228 
1229 template <typename T>
1230 class Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid : public ConvolutionTest {
1231  public:
RunTest()1232   void RunTest() {
1233     XlaBuilder builder(TestName());
1234     std::vector<int64_t> input_dims = {1, 2, 2, 1024};
1235     std::vector<int64_t> filter_dims = {2, 2, 128, 8};
1236     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1237     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1238     {
1239       auto input = Parameter(&builder, 0, input_shape, "input");
1240       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1241 
1242       // Tensorflow dimension numbers for 2D convolution.
1243       ConvolutionDimensionNumbers dnums;
1244       dnums.set_input_batch_dimension(0);
1245       dnums.set_output_batch_dimension(0);
1246       dnums.add_input_spatial_dimensions(1);
1247       dnums.add_output_spatial_dimensions(1);
1248       dnums.add_input_spatial_dimensions(2);
1249       dnums.add_output_spatial_dimensions(2);
1250       dnums.set_input_feature_dimension(3);
1251       dnums.set_output_feature_dimension(3);
1252       dnums.add_kernel_spatial_dimensions(0);
1253       dnums.add_kernel_spatial_dimensions(1);
1254       dnums.set_kernel_input_feature_dimension(2);
1255       dnums.set_kernel_output_feature_dimension(3);
1256 
1257       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1258                                 /*feature_group_count=*/8);
1259     }
1260 
1261     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
1262                                static_cast<T>(1));
1263 
1264     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1265     auto input_r4 = input_r1.Reshape(input_dims).value();
1266 
1267     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
1268                                 static_cast<T>(2));
1269 
1270     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1271     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1272 
1273     std::vector<T> output_elems(8, static_cast<T>(1024));
1274     auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
1275     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 8}).value();
1276 
1277     auto input_literal = client_->TransferToServer(input_r4).value();
1278     auto filter_literal = client_->TransferToServer(filter_r4).value();
1279 
1280     ComputeAndCompareLiteral(&builder, expected_r4,
1281                              {input_literal.get(), filter_literal.get()},
1282                              error_spec_);
1283   }
1284 };
1285 
// Instantiate the typed fixture once per element type in TestTypes; the test
// body only delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, Types) {
  TestFixture::RunTest();
}
1290 
1291 template <typename T>
1292 class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid : public ConvolutionTest {
1293  public:
RunTest()1294   void RunTest() {
1295     XlaBuilder builder(TestName());
1296     std::vector<int64_t> input_dims = {1, 2, 2, 12};
1297     std::vector<int64_t> filter_dims = {2, 2, 3, 4};
1298     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1299     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1300     {
1301       auto input = Parameter(&builder, 0, input_shape, "input");
1302       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1303 
1304       // Tensorflow dimension numbers for 2D convolution.
1305       ConvolutionDimensionNumbers dnums;
1306       dnums.set_input_batch_dimension(0);
1307       dnums.set_output_batch_dimension(0);
1308       dnums.add_input_spatial_dimensions(1);
1309       dnums.add_output_spatial_dimensions(1);
1310       dnums.add_input_spatial_dimensions(2);
1311       dnums.add_output_spatial_dimensions(2);
1312       dnums.set_input_feature_dimension(3);
1313       dnums.set_output_feature_dimension(3);
1314       dnums.add_kernel_spatial_dimensions(0);
1315       dnums.add_kernel_spatial_dimensions(1);
1316       dnums.set_kernel_input_feature_dimension(2);
1317       dnums.set_kernel_output_feature_dimension(3);
1318 
1319       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1320                                 /*feature_group_count=*/4);
1321     }
1322 
1323     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1324     iota_int_init_value(input_elems, 1);
1325     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1326     auto input_r4 = input_r1.Reshape(input_dims).value();
1327 
1328     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1329     iota_int_init_value(filter_elems, 1);
1330     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1331     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1332 
1333     auto expected_r1 =
1334         LiteralUtil::CreateR1<T>({static_cast<T>(7712), static_cast<T>(8816),
1335                                   static_cast<T>(9992), static_cast<T>(11240)});
1336     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).value();
1337 
1338     auto input_literal = client_->TransferToServer(input_r4).value();
1339     auto filter_literal = client_->TransferToServer(filter_r4).value();
1340 
1341     ComputeAndCompareLiteral(&builder, expected_r4,
1342                              {input_literal.get(), filter_literal.get()},
1343                              error_spec_);
1344   }
1345 };
1346 
// Instantiate the typed fixture once per element type in TestTypes; the test
// body only delegates to the fixture's RunTest().
TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, Types) {
  TestFixture::RunTest();
}
1351 
1352 template <typename T>
1353 class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes
1354     : public ConvolutionTest {
1355  public:
RunTest()1356   void RunTest() {
1357     XlaBuilder builder(TestName());
1358     std::vector<int64_t> input_dims = {1, 2, 2, 12};
1359     std::vector<int64_t> filter_dims = {2, 2, 4, 3};
1360     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1361     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1362     {
1363       auto input = Parameter(&builder, 0, input_shape, "input");
1364       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1365 
1366       // Tensorflow dimension numbers for 2D convolution.
1367       ConvolutionDimensionNumbers dnums;
1368       dnums.set_input_batch_dimension(0);
1369       dnums.set_output_batch_dimension(0);
1370       dnums.add_input_spatial_dimensions(1);
1371       dnums.add_output_spatial_dimensions(1);
1372       dnums.add_input_spatial_dimensions(2);
1373       dnums.add_output_spatial_dimensions(2);
1374       dnums.set_input_feature_dimension(3);
1375       dnums.set_output_feature_dimension(3);
1376       dnums.add_kernel_spatial_dimensions(0);
1377       dnums.add_kernel_spatial_dimensions(1);
1378       dnums.set_kernel_input_feature_dimension(3);
1379       dnums.set_kernel_output_feature_dimension(2);
1380 
1381       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1382                                 /*feature_group_count=*/4);
1383     }
1384 
1385     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1386     iota_int_init_value(input_elems, 1);
1387     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1388     auto input_r4 = input_r1.Reshape(input_dims).value();
1389 
1390     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1391     iota_int_init_value(filter_elems, 1);
1392     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1393     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1394     auto filter_r4_relaid =
1395         filter_r4.Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
1396     auto expected_r1 = LiteralUtil::CreateR1<T>(
1397         {static_cast<T>(6968), static_cast<T>(8516), static_cast<T>(10280),
1398          static_cast<T>(12260)});
1399     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).value();
1400 
1401     auto input_literal = client_->TransferToServer(input_r4).value();
1402     auto filter_literal = client_->TransferToServer(filter_r4_relaid).value();
1403 
1404     ComputeAndCompareLiteral(&builder, expected_r4,
1405                              {input_literal.get(), filter_literal.get()},
1406                              error_spec_);
1407   }
1408 };
1409 
// Registers the Filter_OF_In_Sublanes typed fixture for every element type in
// TestTypes and runs its RunTest() for each one.
// TYPED_TEST_SUITE replaces the deprecated TYPED_TEST_CASE spelling.
TYPED_TEST_SUITE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes,
                 TestTypes);
TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes,
           Types) {
  this->RunTest();
}
1416 
1417 template <typename T>
1418 class Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid : public ConvolutionTest {
1419  public:
RunTest()1420   void RunTest() {
1421     XlaBuilder builder(TestName());
1422     std::vector<int64_t> input_dims = {1, 1, 1, 12};
1423     std::vector<int64_t> filter_dims = {1, 1, 3, 4};
1424     Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
1425     Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
1426     {
1427       auto input = Parameter(&builder, 0, input_shape, "input");
1428       auto filter = Parameter(&builder, 1, filter_shape, "filter");
1429 
1430       // Tensorflow dimension numbers for 2D convolution.
1431       ConvolutionDimensionNumbers dnums;
1432       dnums.set_input_batch_dimension(0);
1433       dnums.set_output_batch_dimension(0);
1434       dnums.add_input_spatial_dimensions(1);
1435       dnums.add_output_spatial_dimensions(1);
1436       dnums.add_input_spatial_dimensions(2);
1437       dnums.add_output_spatial_dimensions(2);
1438       dnums.set_input_feature_dimension(3);
1439       dnums.set_output_feature_dimension(3);
1440       dnums.add_kernel_spatial_dimensions(0);
1441       dnums.add_kernel_spatial_dimensions(1);
1442       dnums.set_kernel_input_feature_dimension(2);
1443       dnums.set_kernel_output_feature_dimension(3);
1444 
1445       ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
1446                                 /*feature_group_count=*/4);
1447     }
1448 
1449     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
1450     iota_int_init_value(input_elems, 1);
1451     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
1452     auto input_r4 = input_r1.Reshape(input_dims).value();
1453 
1454     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1455     iota_int_init_value(filter_elems, 1);
1456     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
1457     auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1458 
1459     auto expected_r1 =
1460         LiteralUtil::CreateR1<T>({static_cast<T>(38), static_cast<T>(98),
1461                                   static_cast<T>(176), static_cast<T>(272)});
1462     auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).value();
1463 
1464     auto input_literal = client_->TransferToServer(input_r4).value();
1465     auto filter_literal = client_->TransferToServer(filter_r4).value();
1466 
1467     ComputeAndCompareLiteral(&builder, expected_r4,
1468                              {input_literal.get(), filter_literal.get()},
1469                              error_spec_);
1470   }
1471 };
1472 
// Registers the Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid typed fixture for
// every element type in TestTypes and runs its RunTest() for each one.
// TYPED_TEST_SUITE replaces the deprecated TYPED_TEST_CASE spelling.
TYPED_TEST_SUITE(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, Types) {
  this->RunTest();
}
1477 
1478 // Test fixture to run convolution tests with and without convolution
1479 // canonicalization enabled.
1480 class ConvolveWithAndWithoutCanonicalization
1481     : public ConvolutionTest,
1482       public ::testing::WithParamInterface<bool> {};
1483 
// Convolution with no spatial dimensions at all.
//
// The 4x29 input has its feature dimension first and its batch dimension
// second. The 4x10 filter maps the 4 input features to 10 output features.
// The result is therefore 29x10 (batch x feature), i.e. transpose(input)
// multiplied by the filter.
//
// When the test parameter is true, the "convolution-canonicalization" pass is
// disabled for this run.
XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
           DISABLED_ON_GPU(Convolve2D_NoSpatialDims)) {
  if (GetParam()) {
    execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
        "convolution-canonicalization");
  }
  XlaBuilder builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29});
  Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10});

  auto input = Parameter(&builder, 0, input_shape, "input");
  auto filter = Parameter(&builder, 1, filter_shape, "filter");

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_feature_dimension(0);
  dnums.set_input_batch_dimension(1);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  // No window strides: there are no spatial dimensions to stride over.
  ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums);

  Array2D<float> param0(4, 29);
  param0.FillUnique();

  Array2D<float> param1(4, 10);
  param1.FillUnique();

  // ComputeAndCompare is given no expected value; it derives the expected
  // result itself, so no hand-built expected array is needed.
  ComputeAndCompare(&builder,
                    {LiteralUtil::CreateFromArray(param0),
                     LiteralUtil::CreateFromArray(param1)},
                    error_spec_);
}
1520 
1521 INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation,
1522                         ConvolveWithAndWithoutCanonicalization,
1523                         ::testing::Values(true, false));
1524 
// Valid-padded, unit-stride convolution of two bf16 operands of shape
// 1x1x1x2. The input holds {1, 2} and the filter {5, 6}, both placed with
// FillWithYX. ComputeAndCompare derives the expected result itself.
XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
  XlaBuilder builder(TestName());
  // Both operands share one shape.
  const Shape operand_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
  auto input = Parameter(&builder, 0, operand_shape, "input");
  auto filter = Parameter(&builder, 1, operand_shape, "filter");
  Conv(input, filter, {1, 1}, Padding::kValid);

  const Array2D<bfloat16> input_yx({{bfloat16(1), bfloat16(2)}});
  const Array2D<bfloat16> filter_yx({{bfloat16(5), bfloat16(6)}});

  Array4D<bfloat16> input_values(1, 1, 1, 2);
  input_values.FillWithYX(input_yx);
  Array4D<bfloat16> filter_values(1, 1, 1, 2);
  filter_values.FillWithYX(filter_yx);

  ComputeAndCompare(&builder,
                    {LiteralUtil::CreateFromArray(input_values),
                     LiteralUtil::CreateFromArray(filter_values)},
                    error_spec_);
}
1547 
1548 // Check that GPU convs still work if the CudnnAlgorithmPicker pass is disabled.
1549 // (We run this test on all platforms, because, what the heck.)
XLA_TEST_F(ConvolutionTest,DISABLED_ON_GPU_ROCM (NoCudnnAlgorithmPicker))1550 XLA_TEST_F(ConvolutionTest, DISABLED_ON_GPU_ROCM(NoCudnnAlgorithmPicker)) {
1551   execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
1552       "gpu-conv-algorithm-picker");
1553 
1554   XlaBuilder builder(TestName());
1555   Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1556   Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1557   auto input = Parameter(&builder, 0, input_shape, "input");
1558   auto filter = Parameter(&builder, 1, filter_shape, "filter");
1559   Conv(input, filter, {1, 1}, Padding::kValid);
1560 
1561   Array4D<float> input_data(1, 1, 1, 2);
1562   input_data.FillIota(0);
1563   Array4D<float> filter_data(1, 1, 1, 2);
1564   filter_data.FillIota(10);
1565 
1566   ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
1567                                LiteralUtil::CreateFromArray(filter_data)});
1568 }
1569 
// Depthwise grouped convolution: feature_group_count = 64 with one input
// feature per group.
//
// The runtime input is a random 1x64x100x100 array in bf01 order. The filter
// is a random 7x7x1x64 array in 01io order, embedded in the computation as a
// constant. Padding of 3 on every side keeps the 100x100 spatial size.
XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
  XlaBuilder builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100});
  Array4D<float> input_data(1, 64, 100, 100);
  input_data.FillRandom(/*stddev=*/0.023, /*mean=*/0.001, /*seed=*/45321);
  // The filter is a constant, not a parameter, so it needs no Shape.
  Array4D<float> filter_data(7, 7, 1, 64);
  filter_data.FillRandom(/*stddev=*/0.023, /*mean=*/0.001, /*seed=*/45320);
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto filter = ConstantR4FromArray4D(&builder, filter_data);

  // Specify bf01_01io->bf01 as dimension numbers.
  ConvolutionDimensionNumbers dnums;
  // Input
  dnums.set_input_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  // Kernel
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);
  // Output
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);
  ConvGeneral(input, filter, /*window_strides=*/{1, 1},
              /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
              /*feature_group_count=*/64);

  ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
                    error_spec_);
}
1605 
1606 class ConvolutionHloTest : public HloTestBase {};
1607 
1608 // double datatype is not yet supported in ROCm
XLA_TEST_F(ConvolutionHloTest,DISABLED_ON_GPU_ROCM (ConvolveF64Forward))1609 XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU_ROCM(ConvolveF64Forward)) {
1610   constexpr char kHlo[] = R"(
1611 HloModule TestModule
1612 
1613 ENTRY Test {
1614   %arg0 = f64[3,56,56,16] parameter(0)
1615   %arg1 = f64[3,3,3,64] parameter(1)
1616   ROOT %conv = f64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
1617 })";
1618   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
1619 }
1620 
// Forward complex64 convolution.
//
// The [3,56,56,16] input is f01b (3 features, batch 16). The 3x3 i01o kernel
// produces 64 output features. There is no padding, so 56x56 shrinks to 54x54.
// Disabled on all GPU backends.
XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU(ConvolveC64Forward)) {
  constexpr char kHloModule[] = R"(
HloModule TestModule

ENTRY Test {
  %arg0 = c64[3,56,56,16] parameter(0)
  %arg1 = c64[3,3,3,64] parameter(1)
  ROOT %conv = c64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
})";
  const ErrorSpec tolerance{0.01, 0.01};
  EXPECT_TRUE(RunAndCompare(kHloModule, tolerance));
}
1632 
// Forward f32 convolution with rhs_reversal=1x1, which reverses the kernel
// along both spatial dimensions.
//
// The [3,56,56,16] input is f01b. The 3x3 i01o kernel produces 32 output
// features. There is no padding, so 56x56 shrinks to 54x54. Disabled on ROCm.
XLA_TEST_F(ConvolutionHloTest,
           DISABLED_ON_GPU_ROCM(ConvolveF32ForwardReversed)) {
  constexpr char kHloModule[] = R"(
HloModule TestModule

ENTRY Test {
  %arg0 = f32[3,56,56,16] parameter(0)
  %arg1 = f32[3,3,3,32] parameter(1)
  ROOT %conv = f32[54,54,16,32] convolution(%arg0, %arg1), window={size=3x3 rhs_reversal=1x1}, dim_labels=f01b_i01o->01bf
})";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(kHloModule, tolerance));
}
1645 
1646 // double datatype is not yet supported in ROCm
XLA_TEST_F(ConvolutionHloTest,DISABLED_ON_GPU_ROCM (ConvolveF64BackwardFilter))1647 XLA_TEST_F(ConvolutionHloTest,
1648            DISABLED_ON_GPU_ROCM(ConvolveF64BackwardFilter)) {
1649   constexpr char kHlo[] = R"(
1650 HloModule TestModule
1651 
1652 ENTRY Test {
1653   %arg0 = f64[2,5,8,1] parameter(0)
1654   %arg1 = f64[2,5,8,2] parameter(1)
1655   ROOT %conv = f64[4,4,1,2] convolution(%arg0, %arg1), window={size=5x8 pad=1_2x1_2}, dim_labels=f01b_i01o->01bf
1656 })";
1657   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
1658 }
1659 
1660 // double datatype is not yet supported in ROCm
XLA_TEST_F(ConvolutionHloTest,DISABLED_ON_GPU_ROCM (ConvolveF64BackwardInput))1661 XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU_ROCM(ConvolveF64BackwardInput)) {
1662   constexpr char kHlo[] = R"(
1663 HloModule TestModule
1664 
1665 ENTRY Test {
1666   %output = f64[4,5,16,16] parameter(0)
1667   %kernel = f64[5,3,7,7] parameter(1)
1668   %reverse = f64[5,3,7,7] reverse(f64[5,3,7,7] %kernel), dimensions={2,3}
1669   ROOT %convolution = f64[4,3,16,16] convolution(%output, %reverse), window={size=7x7 pad=3_3x3_3}, dim_labels=bf01_io01->bf01
1670 })";
1671   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
1672 }
1673 
// Backward-input style f32 convolution.
//
// The [672,7,7,64] o01i kernel is reversed along its spatial dimensions {1,2}.
// The 3x3 (01bf) input is padded by 6 on each side, producing a 9x9 output
// with 672 features.
XLA_TEST_F(ConvolutionHloTest, ConvolveBackwardInput) {
  constexpr char kHloModule[] = R"(
HloModule TestModule

ENTRY Test {
  %output = f32[3,3,64,64] parameter(0)
  %kernel = f32[672,7,7,64] parameter(1)
  %reverse = f32[672,7,7,64]{3,2,1,0} reverse(f32[672,7,7,64]{3,2,1,0} %kernel), dimensions={1,2}
  ROOT %convolution = f32[672,9,9,64]{3,2,1,0} convolution(f32[3,3,64,64]{3,2,1,0} %output, f32[672,7,7,64]{3,2,1,0} %reverse), window={size=7x7 pad=6_6x6_6}, dim_labels=01bf_o01i->f01b
})";
  const ErrorSpec tolerance{0.01, 0.01};
  EXPECT_TRUE(RunAndCompare(kHloModule, tolerance));
}
1686 
// Convolution whose 11x11 window is larger than the 3x3 input. Asymmetric
// padding (3_25 x 3_6) yields a 21x2 output. The name suggests it targets
// backends that swap operands in this situation (NOTE(review): unconfirmed).
XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolve) {
  constexpr char kHloModule[] = R"(
HloModule TestModule

ENTRY Test {
  %lhs = f32[3,3,7,7] parameter(0)
  %rhs = f32[5,11,11,7] parameter(1)
  ROOT %convolution = f32[5,21,2,7] convolution(lhs, rhs),
     window={size=11x11 pad=3_25x3_6},
     dim_labels=01bf_o01i->f01b
})";
  const ErrorSpec tolerance{0.01, 0.01};
  EXPECT_TRUE(RunAndCompare(kHloModule, tolerance));
}
1700 
// Variant of the larger-than-input window case (11x11 window over a 3x3
// input) with padding 3_26 x 3_6 and a stride of 2 in the first spatial
// dimension. This gives an 11x2 output.
XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolveWithStride) {
  constexpr char kHloModule[] = R"(
HloModule TestModule

ENTRY Test {
  %lhs = f32[3,3,7,7] parameter(0)
  %rhs = f32[5,11,11,7] parameter(1)
  ROOT %convolution = f32[5,11,2,7] convolution(lhs, rhs),
     window={size=11x11 pad=3_26x3_6 stride=2x1},
     dim_labels=01bf_o01i->f01b
})";
  const ErrorSpec tolerance{0.01, 0.01};
  EXPECT_TRUE(RunAndCompare(kHloModule, tolerance));
}
// Variant of the larger-than-input window case (11x11 window over a 3x3
// input) with both lhs dilation (1x2) and rhs dilation (2x1). With padding
// 3_25 x 3_6 this gives an 11x4 output.
XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolve2) {
  constexpr char kHloModule[] = R"(
HloModule TestModule

ENTRY Test {
  %lhs = f32[3,3,7,7] parameter(0)
  %rhs = f32[5,11,11,7] parameter(1)
  ROOT %convolution = f32[5,11,4,7] convolution(lhs, rhs),
     window={size=11x11 pad=3_25x3_6 lhs_dilate=1x2 rhs_dilate=2x1},
     dim_labels=01bf_o01i->f01b
})";
  const ErrorSpec tolerance{0.01, 0.01};
  EXPECT_TRUE(RunAndCompare(kHloModule, tolerance));
}
1727 
// Convolution with no spatial dimensions and no window. The [10,5] bf input
// and [5,7] io kernel make this equivalent to a [10,5] x [5,7] matrix product.
XLA_TEST_F(ConvolutionHloTest, TestConv0D) {
  constexpr char kHloModule[] = R"(
HloModule TestModule

ENTRY TestComputation {
  %parameter.1 = f32[10,5]{1,0} parameter(0)
  %parameter.2 = f32[5,7]{1,0} parameter(1)
  ROOT %convolution.3 = f32[10,7]{1,0} convolution(f32[10,5]{1,0} %parameter.1, f32[5,7]{1,0} %parameter.2), dim_labels=bf_io->bf
})";
  const ErrorSpec tolerance{0.01, 0.01};
  EXPECT_TRUE(RunAndCompare(kHloModule, tolerance));
}
1739 
// 2D convolution followed by a bias add and ReLU (maximum with zero). The
// convolution is 3x3, padded by 1 on every side, with b01f input and 01io
// kernel. This is a pattern backends may fuse.
XLA_TEST_F(ConvolutionHloTest, TestFusedConv2D) {
  constexpr char kHloModule[] = R"(
HloModule TestModule

ENTRY TestComputation {
  %p0 = f32[8,5,5,1] parameter(0)
  %p1 = f32[3,3,1,32] parameter(1)
  %conv = f32[8,5,5,32] convolution(p0, p1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
  %bias = f32[32] parameter(2)
  %broadcasted_bias = f32[8,5,5,32] broadcast(%bias), dimensions={3}
  %add = f32[8,5,5,32] add(%conv, %broadcasted_bias)
  %zero = f32[] constant(0)
  %zeros = f32[8,5,5,32] broadcast(%zero), dimensions={}
  ROOT relu = f32[8,5,5,32] maximum(%zeros, %add)
})";
  const ErrorSpec tolerance{0.01, 0.01};
  EXPECT_TRUE(RunAndCompare(kHloModule, tolerance));
}
1757 
// 3D convolution followed by a bias add and ReLU (maximum with zero). The
// convolution is 3x3x3, padded by 1 on every side, with b012f input and 012io
// kernel. This is a pattern backends may fuse.
XLA_TEST_F(ConvolutionHloTest, TestFusedConv3D) {
  constexpr char kHloModule[] = R"(
HloModule TestModule

ENTRY TestComputation {
  %p0 = f32[8,4,5,5,1] parameter(0)
  %p1 = f32[3,3,3,1,32] parameter(1)
  %conv = f32[8,4,5,5,32] convolution(p0, p1), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f
  %bias = f32[32] parameter(2)
  %broadcasted_bias = f32[8,4,5,5,32] broadcast(%bias), dimensions={4}
  %add = f32[8,4,5,5,32] add(%conv, %broadcasted_bias)
  %zero = f32[] constant(0)
  %zeros = f32[8,4,5,5,32] broadcast(%zero), dimensions={}
  ROOT relu = f32[8,4,5,5,32] maximum(%zeros, %add)
})";
  const ErrorSpec tolerance{0.01, 0.01};
  EXPECT_TRUE(RunAndCompare(kHloModule, tolerance));
}
1775 
// Convolution over pred (boolean) operands. An all-true [3,3,3] array is
// convolved with itself using a 1D window of size 3 and padding 1. The result
// is wrapped in a tuple.
XLA_TEST_F(ConvolutionHloTest, TestBooleanInput) {
  constexpr char kHloModule[] = R"(
HloModule TestModule

ENTRY TestComputation {
  constant.1 = pred[] constant(true)
  broadcast.2 = pred[3,3,3]{2,1,0} broadcast(constant.1), dimensions={}
  convolution.3 = pred[3,3,3]{2,1,0} convolution(broadcast.2, broadcast.2), window={size=3 pad=1_1}, dim_labels=bf0_oi0->bf0
  ROOT tuple.4 = (pred[3,3,3]{2,1,0}) tuple(convolution.3)
})";
  const ErrorSpec tolerance{0.01, 0.01};
  EXPECT_TRUE(RunAndCompare(kHloModule, tolerance));
}
1788 
1789 }  // namespace
1790 }  // namespace xla
1791