• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <cmath>
#include <limits>
#include <memory>
#include <type_traits>
#include <vector>
18 
19 #include "tensorflow/compiler/xla/client/global_data.h"
20 #include "tensorflow/compiler/xla/client/local_client.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
23 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
24 #include "tensorflow/compiler/xla/tests/test_macros.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/test.h"
27 
28 namespace xla {
29 namespace {
30 
31 class UnaryOpTest : public ClientLibraryTestBase {
32  protected:
33   template <typename T>
inf()34   T inf() {
35     return std::numeric_limits<T>::infinity();
36   }
37   template <typename T>
AbsSize0TestHelper()38   void AbsSize0TestHelper() {
39     XlaBuilder builder(TestName());
40     auto arg = ConstantR1<T>(&builder, {});
41     Abs(arg);
42 
43     if (primitive_util::NativeToPrimitiveType<T>() == C64) {
44       ComputeAndCompareR1<float>(&builder, {}, {});
45     } else {
46       ComputeAndCompareR1<T>(&builder, {}, {});
47     }
48   }
49 
50   template <typename T>
AbsTestHelper()51   void AbsTestHelper() {
52     XlaBuilder builder(TestName());
53     auto arg = ConstantR1<T>(&builder, {-2, 25, 0, -123, inf<T>(), -inf<T>()});
54     Abs(arg);
55 
56     ComputeAndCompareR1<T>(&builder, {2, 25, 0, 123, inf<T>(), inf<T>()}, {});
57   }
58 
59   template <typename T>
SignTestHelper()60   void SignTestHelper() {
61     XlaBuilder builder(TestName());
62     auto arg = ConstantR1<T>(
63         &builder, {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()});
64     Sign(arg);
65 
66     ComputeAndCompareR1<T>(
67         &builder,
68         {-1, 1, static_cast<T>(+0.0), static_cast<T>(-0.0), -1, 1, -1}, {});
69   }
70 
71   template <typename T>
SignAbsTestHelper()72   void SignAbsTestHelper() {
73     XlaBuilder builder(TestName());
74     auto arg = ConstantR1<T>(&builder, {-2, 25, 0, -123});
75     auto sign = Sign(arg);
76     auto abs = Abs(arg);
77     Sub(Mul(sign, abs), arg);
78 
79     ComputeAndCompareR1<T>(&builder, {0, 0, 0, 0}, {});
80   }
81 };
82 
83 template <>
inf()84 int UnaryOpTest::inf<int>() {
85   return 2147483647;
86 }
87 
88 template <>
inf()89 int64_t UnaryOpTest::inf<int64_t>() {
90   return 0x7FFFFFFFFFFFFFFFl;
91 }
92 
93 template <>
AbsTestHelper()94 void UnaryOpTest::AbsTestHelper<complex64>() {
95   XlaBuilder builder(TestName());
96   auto arg = ConstantR1<complex64>(&builder, {{-2, 0},
97                                               {0, 25},
98                                               {0, 0},
99                                               {-0.3f, 0.4f},
100                                               {0, inf<float>()},
101                                               {-inf<float>(), 0}});
102   Abs(arg);
103 
104   Literal expected =
105       LiteralUtil::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
106   ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
107 }
108 
109 template <>
SignTestHelper()110 void UnaryOpTest::SignTestHelper<complex64>() {
111   XlaBuilder builder(TestName());
112   auto arg = ConstantR1<complex64>(
113       &builder,
114       {{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
115   Sign(arg);
116 
117   Literal expected = LiteralUtil::CreateR1<complex64>(
118       {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
119   ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
120 }
121 
122 template <>
SignAbsTestHelper()123 void UnaryOpTest::SignAbsTestHelper<complex64>() {
124   XlaBuilder builder(TestName());
125   auto arg =
126       ConstantR1<complex64>(&builder, {{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}});
127   auto sign = Sign(arg);
128   auto abs = Abs(arg);
129   Sub(Mul(sign, ConvertElementType(abs, C64)), arg);
130 
131   Literal expected = LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
132   ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
133 }
134 
// Abs over empty rank-1 arrays must produce empty results for every element
// type exercised by this file. int64_t is included to match the element-type
// coverage of SignTestR1.
XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
  AbsSize0TestHelper<int>();
  AbsSize0TestHelper<int64_t>();
  AbsSize0TestHelper<float>();
  AbsSize0TestHelper<complex64>();
}
140 
// Abs on rank-1 inputs (negative, positive, zero and +/-inf<T>() values) for
// each element type. int64_t is included to match SignTestR1's coverage; its
// "inf" is the int64_t maximum via the explicit inf<int64_t> specialization.
XLA_TEST_F(UnaryOpTest, AbsTestR1) {
  AbsTestHelper<int>();
  AbsTestHelper<int64_t>();
  AbsTestHelper<float>();
  AbsTestHelper<complex64>();
}
146 
// Sums scalar Abs results across element types:
//   |(-0.3, 0.4)| + |-0.0| + |-3.0| + |-5| = 0.5 + 0 + 3 + 5 = 8.5.
// The complex Abs is already F32; the int result is converted to F32 so all
// terms can be added together.
XLA_TEST_F(UnaryOpTest, AbsTestR0) {
  XlaBuilder builder(TestName());
  XlaOp abs_int = Abs(ConstantR0<int>(&builder, -5));
  XlaOp abs_float = Abs(ConstantR0<float>(&builder, -3.0f));
  XlaOp abs_neg_zero = Abs(ConstantR0<float>(&builder, -0.0f));
  XlaOp abs_complex = Abs(ConstantR0<complex64>(&builder, {-0.3f, 0.4f}));

  XlaOp complex_plus_zero = Add(abs_complex, abs_neg_zero);
  XlaOp int_as_float = ConvertElementType(abs_int, F32);
  XlaOp float_plus_int = Add(abs_float, int_as_float);
  Add(complex_plus_zero, float_plus_int);

  ComputeAndCompareR0<float>(&builder, 8.5f, {});
}
161 
// Sums scalar Sign results across element types. Expected terms:
//   sign(-5) = -1, sign(-4.0f) = -1, sign(-0.0f) is zero,
//   sign((-0.3, 0.4)) = (-0.6, 0.8) since |(-0.3, 0.4)| = 0.5,
// giving a total of (-2.6, 0.8).
XLA_TEST_F(UnaryOpTest, SignTestR0) {
  XlaBuilder builder(TestName());
  XlaOp sign_int = Sign(ConstantR0<int>(&builder, -5));
  XlaOp sign_float = Sign(ConstantR0<float>(&builder, -4.0f));
  XlaOp sign_neg_zero = Sign(ConstantR0<float>(&builder, -0.0f));
  XlaOp sign_complex = Sign(ConstantR0<complex64>(&builder, {-0.3f, 0.4f}));

  // Accumulate the real-valued terms in F32, then widen to C64 for the
  // final addition with the complex term.
  XlaOp float_terms = Add(sign_neg_zero, sign_float);
  XlaOp int_term = ConvertElementType(sign_int, F32);
  XlaOp real_sum = Add(float_terms, int_term);
  Add(sign_complex, ConvertElementType(real_sum, C64));

  const Literal want = LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(1e-6f));
}
178 
// Runs the rank-1 Sign check for each element type covered by this file.
XLA_TEST_F(UnaryOpTest, SignTestR1) {
  SignTestHelper<int>();        // generic helper, inf<int> specialization
  SignTestHelper<int64_t>();    // generic helper, inf<int64_t> specialization
  SignTestHelper<float>();      // generic helper, real +/-infinity
  SignTestHelper<complex64>();  // explicit complex64 specialization
}
185 
// Checks sign(x) * |x| - x == 0 on rank-1 inputs for each element type.
// int64_t is included to match SignTestR1's element-type coverage.
XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
  SignAbsTestHelper<int>();
  SignAbsTestHelper<int64_t>();
  SignAbsTestHelper<float>();
  SignAbsTestHelper<complex64>();
}
191 
// Checks sign(x) * |x| - x == 0 elementwise on a 2x2 float matrix.
XLA_TEST_F(UnaryOpTest, SignAbsTestR2) {
  XlaBuilder builder(TestName());
  XlaOp x = ConstantR2<float>(&builder, {{1.0f, -2.0f}, {-3.0f, 4.0f}});
  XlaOp sign_of_x = Sign(x);
  XlaOp abs_of_x = Abs(x);
  XlaOp product = Mul(sign_of_x, abs_of_x);
  Sub(product, x);

  ComputeAndCompareR2<float>(&builder, {{0, 0}, {0, 0}}, {});
}
201 
// Compares {0, 1} with {1, 1} (giving {false, true}) and converts the boolean
// result to S32, expecting false -> 0 and true -> 1.
XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) {
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<int32_t>(&builder, {0, 1});
  XlaOp right = ConstantR1<int32_t>(&builder, {1, 1});
  XlaOp is_equal = Eq(left, right);
  ConvertElementType(is_equal, S32);

  ComputeAndCompareR1<int32_t>(&builder, {0, 1}, {});
}
210 
// Compares {0, 1} with {1, 1} (giving {false, true}) and converts the boolean
// result to F32, expecting false -> 0.0 and true -> 1.0.
XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) {
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<int32_t>(&builder, {0, 1});
  XlaOp right = ConstantR1<int32_t>(&builder, {1, 1});
  XlaOp is_equal = Eq(left, right);
  ConvertElementType(is_equal, F32);

  ComputeAndCompareR1<float>(&builder, {0.0f, 1.0f}, {});
}
219 
220 }  // namespace
221 }  // namespace xla
222