1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <algorithm>
#include <array>
#include <memory>
#include <numeric>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
30
31 namespace xla {
32 namespace {
33
// Precision settings the parameterized reverse tests are expanded over.
// `false` (F32) is always present; `true` (BF16) is appended only when the
// backend defines XLA_BACKEND_SUPPORTS_BFLOAT16.
#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
static std::array<bool, 2> use_bfloat16_params = {{false, true}};
#else
static std::array<bool, 1> use_bfloat16_params = {{false}};
#endif
41
42 struct ReverseSpec {
43 std::vector<int64> input_dims;
44 std::vector<int64> reversal;
45 bool use_bfloat16;
46
ToTestCaseNamexla::__anonab88138b0111::ReverseSpec47 string ToTestCaseName() const {
48 return absl::StrFormat(
49 "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x"),
50 absl::StrJoin(reversal, "x"), use_bfloat16 ? "bf16" : "f32");
51 }
52 };
53
GetTestCases()54 static std::vector<ReverseSpec> GetTestCases() {
55 // clang-format off
56 return ExpandUseBfloat16<ReverseSpec>(
57 use_bfloat16_params,
58 {{{}, {}},
59 {{0, 0}, {0, 1}},
60 {{0, 1}, {0, 1}},
61 {{1, 0}, {0, 1}},
62 {{1, 1}, {0, 1}},
63 {{2, 0, 4, 3}, {0, 2}},
64 {{2, 0, 4, 3}, {1, 3}},
65 {{1, 2, 3, 4}, {0, 3}},
66 {{4, 3, 2, 1}, {0, 1}},
67 });
68 // clang-format on
69 }
70
PrintTo(const ReverseSpec & spec,std::ostream * os)71 void PrintTo(const ReverseSpec& spec, std::ostream* os) {
72 *os << spec.ToTestCaseName();
73 }
74
75 class FloatReverseTest : public ClientLibraryTestBase,
76 public ::testing::WithParamInterface<ReverseSpec> {
77 public:
FloatReverseTest()78 FloatReverseTest() { set_use_bfloat16(GetParam().use_bfloat16); }
79 };
80
// Runs Rev on an iota-filled F32 input of shape spec.input_dims and compares
// against a host-side reference: the element at index `in` must appear at the
// index obtained by mirroring `in` along each dimension in spec.reversal.
TEST_P(FloatReverseTest, Reverses) {
  const ReverseSpec& spec = GetParam();
  std::vector<float> input_vector(
      ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims)));
  // Distinct values 0, 1, 2, ... make any misplaced element a mismatch.
  // Seed with a float so iota accumulates in the element type.
  std::iota(input_vector.begin(), input_vector.end(), 0.0f);
  auto r1_literal = LiteralUtil::CreateR1<float>(input_vector);
  auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto a = AddParam(input_literal, &builder);
  Rev(a, spec.reversal);

  // Reversal permutes positions, so scattering each input element to its
  // mirrored index writes every cell of `expected` exactly once.
  Literal expected = input_literal.Clone();
  std::vector<int64> output_indices(spec.input_dims.size());
  expected.EachCell<float>([&](absl::Span<const int64> indices, float) {
    // std::copy avoids the signed/unsigned index comparison of a manual loop.
    std::copy(indices.begin(), indices.end(), output_indices.begin());
    const float value = input_literal.Get<float>(indices);
    for (int64 dim : spec.reversal) {
      output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
    }
    expected.Set<float>(output_indices, value);
  });
  ComputeAndCompareLiteral(&builder, expected, {});
}
107
// Instantiates FloatReverseTest once per spec returned by GetTestCases();
// each instance is named from PrintTo() via PrintToStringParamName().
INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest,
                        ::testing::ValuesIn(GetTestCases()),
                        ::testing::PrintToStringParamName());
111
112 // A simple test class which not templated by float precision.
113 class ReverseTest : public ClientLibraryTestBase {};
114
115 // Tests the reverse operation on a 4D U8 array on dimension 0 and 3.
XLA_TEST_F(ReverseTest,Reverse4DU8ArrayOnDim23)116 XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) {
117 XlaBuilder b(TestName());
118 // Input shape is U8[1x2x3x4].
119 // clang-format off
120 Array4D<uint8> input({{
121 {{1, 2, 3, 4},
122 {5, 6, 7, 8},
123 {9, 10, 11, 12}},
124 {{13, 14, 15, 16},
125 {17, 18, 19, 20},
126 {21, 22, 23, 24}},
127 }});
128 // clang-format on
129
130 Rev(ConstantR4FromArray4D<uint8>(&b, input), {0, 3});
131
132 // clang-format off
133 Array4D<uint8> expected({{
134 {{4, 3, 2, 1},
135 {8, 7, 6, 5},
136 {12, 11, 10, 9}},
137 {{16, 15, 14, 13},
138 {20, 19, 18, 17},
139 {24, 23, 22, 21}},
140 }});
141 // clang-format on
142 ComputeAndCompareR4<uint8>(&b, expected, {});
143 }
144
145 // Tests the reverse operation on a 4D float array on dimension 0 and 1.
TEST_F(ReverseTest,Reverse4DFloatArrayOnDim01)146 TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) {
147 XlaBuilder b(TestName());
148 // Input shape is float[4x3x2x1].
149 // clang-format off
150 Array4D<float> input({
151 {{{1.0f}, {2.0f}},
152 {{3.0f}, {4.0f}},
153 {{5.0f}, {6.0f}}},
154 {{{7.0f}, {8.0f}},
155 {{9.0f}, {10.0f}},
156 {{11.0f}, {12.0f}}},
157 {{{13.0f}, {14.0f}},
158 {{15.0f}, {16.0f}},
159 {{17.0f}, {18.0f}}},
160 {{{19.0f}, {20.0f}},
161 {{21.0f}, {22.0f}},
162 {{23.0f}, {24.0f}}},
163 });
164 // clang-format on
165
166 Rev(ConstantR4FromArray4D<float>(&b, input), {0, 1});
167
168 // clang-format off
169 Array4D<float> expected({
170 {{{23.0f}, {24.0f}},
171 {{21.0f}, {22.0f}},
172 {{19.0f}, {20.0f}}},
173 {{{17.0f}, {18.0f}},
174 {{15.0f}, {16.0f}},
175 {{13.0f}, {14.0f}}},
176 {{{11.0f}, {12.0f}},
177 {{9.0f}, {10.0f}},
178 {{7.0f}, {8.0f}}},
179 {{{5.0f}, {6.0f}},
180 {{3.0f}, {4.0f}},
181 {{1.0f}, {2.0f}}},
182 });
183 // clang-format on
184 ComputeAndCompareR4<float>(&b, expected, {}, ErrorSpec(0.0001));
185 }
186
187 } // namespace
188 } // namespace xla
189