• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <optional>
17 
18 #include "tensorflow/compiler/xla/client/xla_computation.h"
19 #include "tensorflow/compiler/xla/execution_options_util.h"
20 #include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
21 #include "tensorflow/compiler/xla/service/despecializer.h"
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
25 #include "tensorflow/compiler/xla/tests/conv_depthwise_common.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/tests/test_macros.h"
28 
29 namespace xla {
30 namespace {
31 
32 class DepthwiseConvolution2DTest
33     : public HloTestBase,
34       public ::testing::WithParamInterface<
35           ::testing::tuple<DepthwiseConvolution2DSpec, bool>> {};
36 
GetConv2DTestCases()37 static std::vector<DepthwiseConvolution2DSpec> GetConv2DTestCases() {
38   std::vector<DepthwiseConvolution2DSpec> config_set;
39   std::vector<std::vector<int64_t>> config_options = {
40       {128, 6, 3, 64},  {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64},
41       {144, 5, 2, 256}, {8, 48, 17, 8},   {128, 20, 6, 64}, {64, 14, 12, 172},
42       {16, 9, 4, 16},   {128, 1, 2, 144}, {256, 1, 2, 64},  {256, 1, 2, 2},
43       {144, 5, 3, 3},   {8, 48, 17, 1},   {16, 9, 5, 4}};
44 
45   for (auto option : config_options) {
46     int64_t feature = option[0];
47     int64_t activation_size = option[1];
48     int64_t kernel_size = option[2];
49     int64_t batch = option[3];
50 
51     std::vector<int64_t> kernel_layout = {3, 2, 1, 0};
52     DepthwiseConvolution2DSpec config;
53     config.output_feature = feature;
54     config.window = kernel_size;
55 
56     config.activation_dims = {batch, activation_size, activation_size, feature};
57     config.activation_layout = {3, 0, 2, 1};
58 
59     config.kernel_dims = {kernel_size, kernel_size, 1, feature};
60     config.kernel_layout = {3, 2, 1, 0};
61     config.output_layout = {3, 0, 2, 1};
62 
63     if (activation_size == 1 && kernel_size == 2) {
64       config.stride = config.pad = config.lhs_dilate = -1;
65       // Test for outer dim.
66       config.output_dims = {batch, activation_size + kernel_size - 1,
67                             activation_size + kernel_size, feature};
68     } else if (feature == 256) {
69       // Restrict dilation-based tests only to one feature configuration.
70       config.stride = activation_size - 1;
71       config.pad = 0;
72       config.lhs_dilate = feature / 32;
73       config.output_dims = {batch, feature / 32,
74                             activation_size - kernel_size + 1, feature};
75     } else {
76       config.stride = config.pad = config.lhs_dilate = -1;
77       config.output_dims = {batch, activation_size - kernel_size + 1,
78                             activation_size - kernel_size + 1, feature};
79     }
80     config_set.push_back(config);
81   }
82 
83   return config_set;
84 }
85 
// Builds the HLO text for one depthwise convolution spec, optionally in
// bfloat16, and checks via RunAndCompare that it matches within a tolerance of
// ErrorSpec{0.01, 0.01}.
XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) {
  const auto& param = GetParam();
  const bool use_bfloat16 = ::testing::get<1>(param);

#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_BFLOAT16
  // Backends built without bf16 support silently pass the bf16 variants.
  if (use_bfloat16) {
    return;
  }
#endif

  const DepthwiseConvolution2DSpec& spec = ::testing::get<0>(param);

  // Preprocessing hook handed to RunAndCompare: removes bf16 mixed precision,
  // then runs the Despecializer. Presumably this is applied to the reference
  // module only — confirm against HloTestBase.
  // NOTE(review): if RunAndCompare takes this as std::function<void(...)>, the
  // returned Status is discarded; verify that errors here are not lost.
  auto reference_preprocessor = [](HloModule* module) -> Status {
    BFloat16MixedPrecisionRemoval bf16_remover;
    TF_RETURN_IF_ERROR(bf16_remover.Run(module).status());
    Despecializer despecializer;
    return despecializer.Run(module).status();
  };

  const std::string hlo_text =
      BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16);
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01},
                            reference_preprocessor));
}
107 
108 INSTANTIATE_TEST_CASE_P(
109     DepthwiseConvolution2DTestWithRandomIndices, DepthwiseConvolution2DTest,
110     ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()),
111                        ::testing::Bool()),
112     DepthwiseConvolution2DTestDataToString);
113 
114 }  // namespace
115 }  // namespace xla
116