• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "absl/strings/str_format.h"
20 #include "absl/strings/str_join.h"
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/array4d.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
26 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
27 #include "tensorflow/compiler/xla/tests/test_macros.h"
28 #include "tensorflow/core/platform/test.h"
29 #include "tensorflow/core/platform/types.h"
30 
31 namespace xla {
32 namespace {
33 
// Precisions exercised by the parameterized float tests. F32 (false) is always
// covered; BF16 (true) is added only when the backend defines
// XLA_BACKEND_SUPPORTS_BFLOAT16.
#ifndef XLA_BACKEND_SUPPORTS_BFLOAT16
static std::array<bool, 1> use_bfloat16_params = {false};  // F32 only.
#else
static std::array<bool, 2> use_bfloat16_params = {false, true};  // F32 + BF16.
#endif
41 
42 struct ReverseSpec {
43   std::vector<int64> input_dims;
44   std::vector<int64> reversal;
45   bool use_bfloat16;
46 
ToTestCaseNamexla::__anonab88138b0111::ReverseSpec47   string ToTestCaseName() const {
48     return absl::StrFormat(
49         "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x"),
50         absl::StrJoin(reversal, "x"), use_bfloat16 ? "bf16" : "f32");
51   }
52 };
53 
GetTestCases()54 static std::vector<ReverseSpec> GetTestCases() {
55   // clang-format off
56   return ExpandUseBfloat16<ReverseSpec>(
57       use_bfloat16_params,
58       {{{}, {}},
59         {{0, 0}, {0, 1}},
60         {{0, 1}, {0, 1}},
61         {{1, 0}, {0, 1}},
62         {{1, 1}, {0, 1}},
63         {{2, 0, 4, 3}, {0, 2}},
64         {{2, 0, 4, 3}, {1, 3}},
65         {{1, 2, 3, 4}, {0, 3}},
66         {{4, 3, 2, 1}, {0, 1}},
67       });
68   // clang-format on
69 }
70 
PrintTo(const ReverseSpec & spec,std::ostream * os)71 void PrintTo(const ReverseSpec& spec, std::ostream* os) {
72   *os << spec.ToTestCaseName();
73 }
74 
75 class FloatReverseTest : public ClientLibraryTestBase,
76                          public ::testing::WithParamInterface<ReverseSpec> {
77  public:
FloatReverseTest()78   FloatReverseTest() { set_use_bfloat16(GetParam().use_bfloat16); }
79 };
80 
// Reverses an iota-filled operand of shape spec.input_dims along
// spec.reversal. The result is compared against a reference computed on the
// host. The fixture forwards spec.use_bfloat16 to ClientLibraryTestBase.
TEST_P(FloatReverseTest, Reverses) {
  const ReverseSpec& spec = GetParam();

  // Fill the operand with 0, 1, 2, ... so every element is distinct and any
  // misplaced element shows up in the comparison. Seed with a float so
  // std::iota accumulates in float instead of narrowing a double each step.
  std::vector<float> input_vector(
      ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims)));
  std::iota(input_vector.begin(), input_vector.end(), 0.0f);
  auto r1_literal = LiteralUtil::CreateR1<float>(input_vector);
  auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto a = AddParam(input_literal, &builder);
  Rev(a, spec.reversal);

  // Host reference: the input element at `indices` moves to the index obtained
  // by mirroring each reversed dimension d to (input_dims[d] - 1) - indices[d].
  // Values are read from input_literal rather than from the EachCell callback
  // argument, because cells of `expected` may already have been overwritten
  // by the time they are visited.
  Literal expected = input_literal.Clone();
  std::vector<int64> output_indices(spec.input_dims.size());
  expected.EachCell<float>([&](absl::Span<const int64> indices, float) {
    // assign() reuses output_indices' storage (its size equals the rank), and
    // it avoids the signed/unsigned comparison of a manual index loop.
    output_indices.assign(indices.begin(), indices.end());
    for (int64 dim : spec.reversal) {
      output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
    }
    expected.Set<float>(output_indices, input_literal.Get<float>(indices));
  });
  ComputeAndCompareLiteral(&builder, expected, {});
}
107 
108 INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest,
109                         ::testing::ValuesIn(GetTestCases()),
110                         ::testing::PrintToStringParamName());
111 
112 // A simple test class which not templated by float precision.
113 class ReverseTest : public ClientLibraryTestBase {};
114 
115 // Tests the reverse operation on a 4D U8 array on dimension 0 and 3.
XLA_TEST_F(ReverseTest,Reverse4DU8ArrayOnDim23)116 XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) {
117   XlaBuilder b(TestName());
118   // Input shape is U8[1x2x3x4].
119   // clang-format off
120   Array4D<uint8> input({{
121     {{1, 2, 3, 4},
122      {5, 6, 7, 8},
123      {9, 10, 11, 12}},
124     {{13, 14, 15, 16},
125      {17, 18, 19, 20},
126      {21, 22, 23, 24}},
127   }});
128   // clang-format on
129 
130   Rev(ConstantR4FromArray4D<uint8>(&b, input), {0, 3});
131 
132   // clang-format off
133   Array4D<uint8> expected({{
134     {{4, 3, 2, 1},
135      {8, 7, 6, 5},
136      {12, 11, 10, 9}},
137     {{16, 15, 14, 13},
138      {20, 19, 18, 17},
139      {24, 23, 22, 21}},
140   }});
141   // clang-format on
142   ComputeAndCompareR4<uint8>(&b, expected, {});
143 }
144 
145 // Tests the reverse operation on a 4D float array on dimension 0 and 1.
TEST_F(ReverseTest,Reverse4DFloatArrayOnDim01)146 TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) {
147   XlaBuilder b(TestName());
148   // Input shape is float[4x3x2x1].
149   // clang-format off
150   Array4D<float> input({
151     {{{1.0f}, {2.0f}},
152      {{3.0f}, {4.0f}},
153      {{5.0f}, {6.0f}}},
154     {{{7.0f}, {8.0f}},
155      {{9.0f}, {10.0f}},
156      {{11.0f}, {12.0f}}},
157     {{{13.0f}, {14.0f}},
158      {{15.0f}, {16.0f}},
159      {{17.0f}, {18.0f}}},
160     {{{19.0f}, {20.0f}},
161      {{21.0f}, {22.0f}},
162      {{23.0f}, {24.0f}}},
163   });
164   // clang-format on
165 
166   Rev(ConstantR4FromArray4D<float>(&b, input), {0, 1});
167 
168   // clang-format off
169   Array4D<float> expected({
170     {{{23.0f}, {24.0f}},
171      {{21.0f}, {22.0f}},
172      {{19.0f}, {20.0f}}},
173     {{{17.0f}, {18.0f}},
174      {{15.0f}, {16.0f}},
175      {{13.0f}, {14.0f}}},
176     {{{11.0f}, {12.0f}},
177      {{9.0f}, {10.0f}},
178      {{7.0f}, {8.0f}}},
179     {{{5.0f}, {6.0f}},
180      {{3.0f}, {4.0f}},
181      {{1.0f}, {2.0f}}},
182   });
183   // clang-format on
184   ComputeAndCompareR4<float>(&b, expected, {}, ErrorSpec(0.0001));
185 }
186 
187 }  // namespace
188 }  // namespace xla
189