• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <array>
#include <memory>
#include <numeric>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
29 
30 namespace xla {
31 namespace {
32 
// Precision settings that every FloatReverseTest case is run under.
// false selects F32 and true selects BF16 (see ReverseSpec::ToTestCaseName).
// BF16 is only included when the backend defines
// XLA_BACKEND_SUPPORTS_BFLOAT16.
#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
static std::array<bool, 2> use_bfloat16_params = {false, true};
#else
static std::array<bool, 1> use_bfloat16_params = {false};
#endif
40 
41 struct ReverseSpec {
42   absl::Span<const int64> input_dims;
43   absl::Span<const int64> reversal;
44   bool use_bfloat16;
45 
ToTestCaseNamexla::__anon73f3522d0111::ReverseSpec46   string ToTestCaseName() const {
47     return absl::StrFormat(
48         "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x"),
49         absl::StrJoin(reversal, "x"), use_bfloat16 ? "bf16" : "f32");
50   }
51 };
52 
GetTestCases()53 static std::vector<ReverseSpec> GetTestCases() {
54   // clang-format off
55   return ExpandUseBfloat16<ReverseSpec>(
56       use_bfloat16_params,
57       {{{}, {}},
58         {{0, 0}, {0, 1}},
59         {{0, 1}, {0, 1}},
60         {{1, 0}, {0, 1}},
61         {{1, 1}, {0, 1}},
62         {{2, 0, 4, 3}, {0, 2}},
63         {{2, 0, 4, 3}, {1, 3}},
64         {{1, 2, 3, 4}, {0, 3}},
65         {{4, 3, 2, 1}, {0, 1}},
66       });
67   // clang-format on
68 }
69 
PrintTo(const ReverseSpec & spec,std::ostream * os)70 void PrintTo(const ReverseSpec& spec, std::ostream* os) {
71   *os << spec.ToTestCaseName();
72 }
73 
74 class FloatReverseTest : public ClientLibraryTestBase,
75                          public ::testing::WithParamInterface<ReverseSpec> {
76  public:
FloatReverseTest()77   FloatReverseTest() { set_use_bfloat16(GetParam().use_bfloat16); }
78 };
79 
// Reverses an iota-filled array of shape `spec.input_dims` along the
// dimensions in `spec.reversal`, then compares the result with a reference
// computed on the host.
TEST_P(FloatReverseTest, Reverses) {
  const ReverseSpec& spec = GetParam();

  // Fill the input with 0, 1, 2, ... so every element is distinct and a
  // misplaced element shows up as a mismatch.
  std::vector<float> input_vector(
      ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims)));
  std::iota(input_vector.begin(), input_vector.end(), 0.0f);
  auto r1_literal = LiteralUtil::CreateR1<float>(input_vector);
  auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto a = AddParam(input_literal, &builder);
  Rev(a, spec.reversal);

  // Host-side reference. The input element at `indices` moves to the same
  // position, except that each reversed dimension `dim` is mirrored to
  // (input_dims[dim] - 1) - indices[dim]. The loop walks the unmodified input
  // and writes into a separate copy, so the traversal never reads back its
  // own writes.
  Literal expected = input_literal.Clone();
  std::vector<int64> output_indices(spec.input_dims.size());
  input_literal.EachCell<float>(
      [&](absl::Span<const int64> indices, float value) {
        output_indices.assign(indices.begin(), indices.end());
        for (int64 dim : spec.reversal) {
          output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
        }
        expected.Set<float>(output_indices, value);
      });
  ComputeAndCompareLiteral(&builder, expected, {});
}
106 
107 INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest,
108                         ::testing::ValuesIn(GetTestCases()),
109                         ::testing::PrintToStringParamName());
110 
111 // A simple test class which not templated by float precision.
112 class ReverseTest : public ClientLibraryTestBase {};
113 
114 // Tests the reverse operation on a 4D U8 array on dimension 0 and 3.
XLA_TEST_F(ReverseTest,Reverse4DU8ArrayOnDim23)115 XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) {
116   XlaBuilder b(TestName());
117   // Input shape is U8[1x2x3x4].
118   // clang-format off
119   Array4D<uint8> input({{
120     {{1, 2, 3, 4},
121      {5, 6, 7, 8},
122      {9, 10, 11, 12}},
123     {{13, 14, 15, 16},
124      {17, 18, 19, 20},
125      {21, 22, 23, 24}},
126   }});
127   // clang-format on
128 
129   Rev(ConstantR4FromArray4D<uint8>(&b, input), {0, 3});
130 
131   // clang-format off
132   Array4D<uint8> expected({{
133     {{4, 3, 2, 1},
134      {8, 7, 6, 5},
135      {12, 11, 10, 9}},
136     {{16, 15, 14, 13},
137      {20, 19, 18, 17},
138      {24, 23, 22, 21}},
139   }});
140   // clang-format on
141   ComputeAndCompareR4<uint8>(&b, expected, {});
142 }
143 
144 // Tests the reverse operation on a 4D float array on dimension 0 and 1.
TEST_F(ReverseTest,Reverse4DFloatArrayOnDim01)145 TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) {
146   XlaBuilder b(TestName());
147   // Input shape is float[4x3x2x1].
148   // clang-format off
149   Array4D<float> input({
150     {{{1.0f}, {2.0f}},
151      {{3.0f}, {4.0f}},
152      {{5.0f}, {6.0f}}},
153     {{{7.0f}, {8.0f}},
154      {{9.0f}, {10.0f}},
155      {{11.0f}, {12.0f}}},
156     {{{13.0f}, {14.0f}},
157      {{15.0f}, {16.0f}},
158      {{17.0f}, {18.0f}}},
159     {{{19.0f}, {20.0f}},
160      {{21.0f}, {22.0f}},
161      {{23.0f}, {24.0f}}},
162   });
163   // clang-format on
164 
165   Rev(ConstantR4FromArray4D<float>(&b, input), {0, 1});
166 
167   // clang-format off
168   Array4D<float> expected({
169     {{{23.0f}, {24.0f}},
170      {{21.0f}, {22.0f}},
171      {{19.0f}, {20.0f}}},
172     {{{17.0f}, {18.0f}},
173      {{15.0f}, {16.0f}},
174      {{13.0f}, {14.0f}}},
175     {{{11.0f}, {12.0f}},
176      {{9.0f}, {10.0f}},
177      {{7.0f}, {8.0f}}},
178     {{{5.0f}, {6.0f}},
179      {{3.0f}, {4.0f}},
180      {{1.0f}, {2.0f}}},
181   });
182   // clang-format on
183   ComputeAndCompareR4<float>(&b, expected, {}, ErrorSpec(0.0001));
184 }
185 
186 }  // namespace
187 }  // namespace xla
188