1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <optional>
17
18 #include "tensorflow/compiler/xla/client/xla_computation.h"
19 #include "tensorflow/compiler/xla/execution_options_util.h"
20 #include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
21 #include "tensorflow/compiler/xla/service/despecializer.h"
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
25 #include "tensorflow/compiler/xla/tests/conv_depthwise_common.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/tests/test_macros.h"
28
29 namespace xla {
30 namespace {
31
32 class DepthwiseConvolution2DTest
33 : public HloTestBase,
34 public ::testing::WithParamInterface<
35 ::testing::tuple<DepthwiseConvolution2DSpec, bool>> {};
36
GetConv2DTestCases()37 static std::vector<DepthwiseConvolution2DSpec> GetConv2DTestCases() {
38 std::vector<DepthwiseConvolution2DSpec> config_set;
39 std::vector<std::vector<int64_t>> config_options = {
40 {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64},
41 {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {64, 14, 12, 172},
42 {16, 9, 4, 16}, {128, 1, 2, 144}, {256, 1, 2, 64}, {256, 1, 2, 2},
43 {144, 5, 3, 3}, {8, 48, 17, 1}, {16, 9, 5, 4}};
44
45 for (auto option : config_options) {
46 int64_t feature = option[0];
47 int64_t activation_size = option[1];
48 int64_t kernel_size = option[2];
49 int64_t batch = option[3];
50
51 std::vector<int64_t> kernel_layout = {3, 2, 1, 0};
52 DepthwiseConvolution2DSpec config;
53 config.output_feature = feature;
54 config.window = kernel_size;
55
56 config.activation_dims = {batch, activation_size, activation_size, feature};
57 config.activation_layout = {3, 0, 2, 1};
58
59 config.kernel_dims = {kernel_size, kernel_size, 1, feature};
60 config.kernel_layout = {3, 2, 1, 0};
61 config.output_layout = {3, 0, 2, 1};
62
63 if (activation_size == 1 && kernel_size == 2) {
64 config.stride = config.pad = config.lhs_dilate = -1;
65 // Test for outer dim.
66 config.output_dims = {batch, activation_size + kernel_size - 1,
67 activation_size + kernel_size, feature};
68 } else if (feature == 256) {
69 // Restrict dilation-based tests only to one feature configuration.
70 config.stride = activation_size - 1;
71 config.pad = 0;
72 config.lhs_dilate = feature / 32;
73 config.output_dims = {batch, feature / 32,
74 activation_size - kernel_size + 1, feature};
75 } else {
76 config.stride = config.pad = config.lhs_dilate = -1;
77 config.output_dims = {batch, activation_size - kernel_size + 1,
78 activation_size - kernel_size + 1, feature};
79 }
80 config_set.push_back(config);
81 }
82
83 return config_set;
84 }
85
// Builds the HLO text for one (spec, use_bfloat16) parameter and checks it
// with RunAndCompare using ErrorSpec{0.01, 0.01}. The callback passed to
// RunAndCompare runs BFloat16MixedPrecisionRemoval and then Despecializer on
// the module, propagating the first non-OK status.
//
// On backends built with XLA_BACKEND_DOES_NOT_SUPPORT_BFLOAT16, the bfloat16
// variants are reported as skipped. They are not silently passed, so the
// gap in coverage stays visible in test results.
XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) {
  const DepthwiseConvolution2DSpec& spec = ::testing::get<0>(GetParam());
  const bool use_bfloat16 = ::testing::get<1>(GetParam());

#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_BFLOAT16
  if (use_bfloat16) {
    GTEST_SKIP() << "Backend does not support bfloat16";
  }
#endif

  const std::string hlo_text =
      BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16);

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01},
                            [](HloModule* module) -> Status {
                              BFloat16MixedPrecisionRemoval remover;
                              TF_RETURN_IF_ERROR(remover.Run(module).status());
                              Despecializer despecializer;
                              return despecializer.Run(module).status();
                            }));
}
107
108 INSTANTIATE_TEST_CASE_P(
109 DepthwiseConvolution2DTestWithRandomIndices, DepthwiseConvolution2DTest,
110 ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()),
111 ::testing::Bool()),
112 DepthwiseConvolution2DTestDataToString);
113
114 } // namespace
115 } // namespace xla
116