/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/despecializer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

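// Returns the HLO element type string ("bf16" or "f32") used in the
// generated module text.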
string GetFloatDataType(bool use_bfloat16) {
  return use_bfloat16 ? "bf16" : "f32";
}

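// Describes one depthwise convolution test case: the convolution parameters
// plus the dimensions and layouts of the activation, kernel, and output
// tensors. A value of -1 for stride/pad/lhs_dilate marks the field as unused.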
struct DepthwiseConvolution2DSpec {
  int64 output_feature, window, stride, pad, lhs_dilate;
  std::vector<int64> activation_dims;
  std::vector<int64> activation_layout;
  std::vector<int64> kernel_dims;
  std::vector<int64> kernel_layout;
  std::vector<int64> output_dims;
  std::vector<int64> output_layout;
};

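// Parameterized over a convolution spec and a flag selecting bf16 or f32.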
class DepthwiseConvolution2DTest
    : public HloTestBase,
      public ::testing::WithParamInterface<
          ::testing::tuple<DepthwiseConvolution2DSpec, bool>> {};

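// Enumerates the test configurations. Each config_options entry is
// {output_feature, activation_size, kernel_size, batch}.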
static std::vector<DepthwiseConvolution2DSpec> GetConv2DTestCases() {
  std::vector<DepthwiseConvolution2DSpec> config_set;
  std::vector<std::vector<int64>> config_options = {
      {128, 6, 3, 64},  {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64},
      {144, 5, 2, 256}, {8, 48, 17, 8},   {128, 20, 6, 64}, {64, 14, 12, 172},
      {16, 9, 4, 16},   {128, 1, 2, 144}, {256, 1, 2, 64}};

  for (auto option : config_options) {
    int64 feature = option[0];
    int64 activation_size = option[1];
    int64 kernel_size = option[2];
    int64 batch = option[3];

    std::vector<int64> kernel_layout = {3, 2, 1, 0};
    DepthwiseConvolution2DSpec config;
    config.output_feature = feature;
    config.window = kernel_size;

    config.activation_dims = {batch, activation_size, activation_size, feature};
    config.activation_layout = {3, 0, 2, 1};

    config.kernel_dims = {kernel_size, kernel_size, 1, feature};
    config.kernel_layout = {3, 2, 1, 0};

    if (activation_size == 1 && kernel_size == 2) {
      // Test for outer dim. Stride, pad, and lhs_dilate are not used in this
      // case, so mark them as absent (-1).
      config.stride = config.pad = config.lhs_dilate = -1;
      config.output_dims = {batch, activation_size + kernel_size - 1,
                            activation_size + kernel_size, feature};
    } else if (feature == 256) {
      // Restrict dilation-based tests only to one feature configuration.
      config.stride = activation_size - 1;
      config.pad = 0;
      config.lhs_dilate = feature / 32;
      config.output_dims = {batch, feature / 32,
                            activation_size - kernel_size + 1, feature};
    } else {
      config.stride = config.pad = config.lhs_dilate = -1;
      config.output_dims = {batch, activation_size - kernel_size + 1,
                            activation_size - kernel_size + 1, feature};
    }

    // Try this layout for all kernel shapes.
    config.output_layout = {3, 0, 2, 1};
    config_set.push_back(config);

    // Try other layouts only for certain kernel shapes.
    if (kernel_size % 2 == 0) {
      config.activation_layout = {0, 3, 2, 1};
      config_set.push_back(config);

      config.output_layout = {0, 3, 2, 1};
      config_set.push_back(config);

      config.activation_layout = {3, 0, 2, 1};
      config_set.push_back(config);
    }
  }

  return config_set;
}

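// Builds a human-readable test name from the spec and the data type.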
string DepthwiseConvolution2DTestDataToString(
    const ::testing::TestParamInfo<
        ::testing::tuple<DepthwiseConvolution2DSpec, bool>>& data) {
  const auto& spec = ::testing::get<0>(data.param);
  const string data_type = GetFloatDataType(::testing::get<1>(data.param));
  string str = absl::StrCat(
      "activation_dims_", absl::StrJoin(spec.activation_dims, "x"),
      "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"),
      "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_",
      absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_",
      absl::StrJoin(spec.output_dims, "x"), "_output_layout_",
      absl::StrJoin(spec.output_layout, "_"), data_type);
  // -1 indicates non-existence.
  if (spec.stride != -1) {
    absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1");
  }

  // Test names are not allowed to contain the '-' character.
  absl::c_replace(str, '-', 'n');
  return str;
}

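// Emits the HLO module text for a spec. Three window configurations are
// generated: a padded, rhs-dilated window for the outer-dim case, a plain
// window when stride is unset (-1), and a strided, lhs-dilated window
// otherwise.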
string BuildHloTextDepthwiseConvolution2D(
    const DepthwiseConvolution2DSpec& spec, bool use_bfloat16) {
  const string data_type = GetFloatDataType(use_bfloat16);
  if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) {
    return absl::StrFormat(
        R"(
    HloModule TensorFlowDepthwiseConv

    ENTRY main {
      activation = %s[%s]{%s} parameter(0)
      kernel = %s[%s]{%s} parameter(1)
      ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel),
          window={size=%dx%d  pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f,
          feature_group_count=%d
    }
    )",
        data_type, absl::StrJoin(spec.activation_dims, ","),
        absl::StrJoin(spec.activation_layout, ","), data_type,
        absl::StrJoin(spec.kernel_dims, ","),
        absl::StrJoin(spec.kernel_layout, ","), data_type,
        absl::StrJoin(spec.output_dims, ","),
        absl::StrJoin(spec.output_layout, ","), data_type,
        absl::StrJoin(spec.activation_dims, ","),
        absl::StrJoin(spec.activation_layout, ","), data_type,
        absl::StrJoin(spec.kernel_dims, ","),
        absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window,
        spec.window, spec.window, spec.window, spec.output_feature);

  } else if (spec.stride == -1) {
    return absl::StrFormat(
        R"(
      HloModule TensorFlowDepthwiseConv

      ENTRY main {
        activation = %s[%s]{%s} parameter(0)
        kernel = %s[%s]{%s} parameter(1)
        ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel),
            window={size=%dx%d}, dim_labels=b01f_01io->b01f,
            feature_group_count=%d
      }
      )",
        data_type, absl::StrJoin(spec.activation_dims, ","),
        absl::StrJoin(spec.activation_layout, ","), data_type,
        absl::StrJoin(spec.kernel_dims, ","),
        absl::StrJoin(spec.kernel_layout, ","), data_type,
        absl::StrJoin(spec.output_dims, ","),
        absl::StrJoin(spec.output_layout, ","), data_type,
        absl::StrJoin(spec.activation_dims, ","),
        absl::StrJoin(spec.activation_layout, ","), data_type,
        absl::StrJoin(spec.kernel_dims, ","),
        absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window,
        spec.output_feature);
  } else {
    return absl::StrFormat(
        R"(
    HloModule TensorFlowDepthwiseConv

    ENTRY main {
      activation = %s[%s]{%s} parameter(0)
      kernel = %s[%s]{%s} parameter(1)
      ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel),
          window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1},
          dim_labels=b01f_01io->b01f, feature_group_count=%d
    }
    )",
        data_type, absl::StrJoin(spec.activation_dims, ","),
        absl::StrJoin(spec.activation_layout, ","), data_type,
        absl::StrJoin(spec.kernel_dims, ","),
        absl::StrJoin(spec.kernel_layout, ","), data_type,
        absl::StrJoin(spec.output_dims, ","),
        absl::StrJoin(spec.output_layout, ","), data_type,
        absl::StrJoin(spec.activation_dims, ","),
        absl::StrJoin(spec.activation_layout, ","), data_type,
        absl::StrJoin(spec.kernel_dims, ","),
        absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window,
        spec.stride, 0, 0, spec.lhs_dilate, spec.output_feature);
  }
}

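// Runs the generated HLO and compares the result against the reference
// backend. The reference module is preprocessed to remove bfloat16 mixed
// precision and to despecialize the HLO before comparison.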
XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) {
  const DepthwiseConvolution2DSpec& spec = ::testing::get<0>(GetParam());
  bool use_bfloat16 = ::testing::get<1>(GetParam());
  const string hlo_text =
      BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16);

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01},
                            [](HloModule* module) -> Status {
                              BFloat16MixedPrecisionRemoval remover;
                              TF_RETURN_IF_ERROR(remover.Run(module).status());
                              Despecializer despecializer;
                              return despecializer.Run(module).status();
                            }));
}

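// Instantiate the test for every generated config, in both f32 and bf16.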
INSTANTIATE_TEST_CASE_P(
    DepthwiseConvolution2DTestWithRandomIndices, DepthwiseConvolution2DTest,
    ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()),
                       ::testing::Bool()),
    DepthwiseConvolution2DTestDataToString);

}  // namespace
}  // namespace xla