1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <array>
#include <memory>
#include <numeric>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
29
30 namespace xla {
31 namespace {
32
#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
// The backend can execute BF16, so parameterized cases are generated for
// both precisions: F32 (false) and BF16 (true).
static std::array<bool, 2> use_bfloat16_params = {false, true};
#else
// No BF16 support on this backend: parameterized cases run in F32 only.
static std::array<bool, 1> use_bfloat16_params = {false};
#endif
40
41 struct ReverseSpec {
42 absl::Span<const int64> input_dims;
43 absl::Span<const int64> reversal;
44 bool use_bfloat16;
45
ToTestCaseNamexla::__anon73f3522d0111::ReverseSpec46 string ToTestCaseName() const {
47 return absl::StrFormat(
48 "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x"),
49 absl::StrJoin(reversal, "x"), use_bfloat16 ? "bf16" : "f32");
50 }
51 };
52
GetTestCases()53 static std::vector<ReverseSpec> GetTestCases() {
54 // clang-format off
55 return ExpandUseBfloat16<ReverseSpec>(
56 use_bfloat16_params,
57 {{{}, {}},
58 {{0, 0}, {0, 1}},
59 {{0, 1}, {0, 1}},
60 {{1, 0}, {0, 1}},
61 {{1, 1}, {0, 1}},
62 {{2, 0, 4, 3}, {0, 2}},
63 {{2, 0, 4, 3}, {1, 3}},
64 {{1, 2, 3, 4}, {0, 3}},
65 {{4, 3, 2, 1}, {0, 1}},
66 });
67 // clang-format on
68 }
69
PrintTo(const ReverseSpec & spec,std::ostream * os)70 void PrintTo(const ReverseSpec& spec, std::ostream* os) {
71 *os << spec.ToTestCaseName();
72 }
73
74 class FloatReverseTest : public ClientLibraryTestBase,
75 public ::testing::WithParamInterface<ReverseSpec> {
76 public:
FloatReverseTest()77 FloatReverseTest() { set_use_bfloat16(GetParam().use_bfloat16); }
78 };
79
// Reverses an operand of shape spec.input_dims, filled with 0, 1, 2, ... in
// row-major order, along spec.reversal. The result is compared against a
// reference literal computed on the host.
TEST_P(FloatReverseTest, Reverses) {
  const ReverseSpec& spec = GetParam();

  // Seed iota with a float (not a double) so the accumulator has the
  // element type of the vector it fills.
  std::vector<float> input_vector(
      ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims)));
  std::iota(input_vector.begin(), input_vector.end(), 0.0f);
  auto r1_literal = LiteralUtil::CreateR1<float>(input_vector);
  auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto a = AddParam(input_literal, &builder);
  Rev(a, spec.reversal);

  // Reference result: input element at index i moves to index j, where j == i
  // except that each reversed dimension d maps i[d] to
  // (input_dims[d] - 1) - i[d].
  Literal expected = input_literal.Clone();
  std::vector<int64> output_indices(spec.input_dims.size());
  expected.EachCell<float>([&](absl::Span<const int64> indices,
                               float /*unused*/) {
    output_indices.assign(indices.begin(), indices.end());
    for (int64 dim : spec.reversal) {
      output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
    }
    // Read from input_literal rather than the cell value passed in, because
    // `expected` is being overwritten while it is iterated.
    expected.Set<float>(output_indices, input_literal.Get<float>(indices));
  });
  ComputeAndCompareLiteral(&builder, expected, {});
}
106
107 INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest,
108 ::testing::ValuesIn(GetTestCases()),
109 ::testing::PrintToStringParamName());
110
111 // A simple test class which not templated by float precision.
112 class ReverseTest : public ClientLibraryTestBase {};
113
114 // Tests the reverse operation on a 4D U8 array on dimension 0 and 3.
XLA_TEST_F(ReverseTest,Reverse4DU8ArrayOnDim23)115 XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) {
116 XlaBuilder b(TestName());
117 // Input shape is U8[1x2x3x4].
118 // clang-format off
119 Array4D<uint8> input({{
120 {{1, 2, 3, 4},
121 {5, 6, 7, 8},
122 {9, 10, 11, 12}},
123 {{13, 14, 15, 16},
124 {17, 18, 19, 20},
125 {21, 22, 23, 24}},
126 }});
127 // clang-format on
128
129 Rev(ConstantR4FromArray4D<uint8>(&b, input), {0, 3});
130
131 // clang-format off
132 Array4D<uint8> expected({{
133 {{4, 3, 2, 1},
134 {8, 7, 6, 5},
135 {12, 11, 10, 9}},
136 {{16, 15, 14, 13},
137 {20, 19, 18, 17},
138 {24, 23, 22, 21}},
139 }});
140 // clang-format on
141 ComputeAndCompareR4<uint8>(&b, expected, {});
142 }
143
144 // Tests the reverse operation on a 4D float array on dimension 0 and 1.
TEST_F(ReverseTest,Reverse4DFloatArrayOnDim01)145 TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) {
146 XlaBuilder b(TestName());
147 // Input shape is float[4x3x2x1].
148 // clang-format off
149 Array4D<float> input({
150 {{{1.0f}, {2.0f}},
151 {{3.0f}, {4.0f}},
152 {{5.0f}, {6.0f}}},
153 {{{7.0f}, {8.0f}},
154 {{9.0f}, {10.0f}},
155 {{11.0f}, {12.0f}}},
156 {{{13.0f}, {14.0f}},
157 {{15.0f}, {16.0f}},
158 {{17.0f}, {18.0f}}},
159 {{{19.0f}, {20.0f}},
160 {{21.0f}, {22.0f}},
161 {{23.0f}, {24.0f}}},
162 });
163 // clang-format on
164
165 Rev(ConstantR4FromArray4D<float>(&b, input), {0, 1});
166
167 // clang-format off
168 Array4D<float> expected({
169 {{{23.0f}, {24.0f}},
170 {{21.0f}, {22.0f}},
171 {{19.0f}, {20.0f}}},
172 {{{17.0f}, {18.0f}},
173 {{15.0f}, {16.0f}},
174 {{13.0f}, {14.0f}}},
175 {{{11.0f}, {12.0f}},
176 {{9.0f}, {10.0f}},
177 {{7.0f}, {8.0f}}},
178 {{{5.0f}, {6.0f}},
179 {{3.0f}, {4.0f}},
180 {{1.0f}, {2.0f}}},
181 });
182 // clang-format on
183 ComputeAndCompareR4<float>(&b, expected, {}, ErrorSpec(0.0001));
184 }
185
186 } // namespace
187 } // namespace xla
188