1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <limits>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
27
28 namespace xla {
29 namespace {
30
31 class UnaryOpTest : public ClientLibraryTestBase {
32 protected:
33 template <typename T>
inf()34 T inf() {
35 return std::numeric_limits<T>::infinity();
36 }
37 template <typename T>
AbsSize0TestHelper()38 void AbsSize0TestHelper() {
39 XlaBuilder builder(TestName());
40 auto arg = ConstantR1<T>(&builder, {});
41 Abs(arg);
42
43 if (primitive_util::NativeToPrimitiveType<T>() == C64) {
44 ComputeAndCompareR1<float>(&builder, {}, {});
45 } else {
46 ComputeAndCompareR1<T>(&builder, {}, {});
47 }
48 }
49
50 template <typename T>
AbsTestHelper()51 void AbsTestHelper() {
52 XlaBuilder builder(TestName());
53 auto arg = ConstantR1<T>(&builder, {-2, 25, 0, -123, inf<T>(), -inf<T>()});
54 Abs(arg);
55
56 ComputeAndCompareR1<T>(&builder, {2, 25, 0, 123, inf<T>(), inf<T>()}, {});
57 }
58
59 template <typename T>
SignTestHelper()60 void SignTestHelper() {
61 XlaBuilder builder(TestName());
62 auto arg = ConstantR1<T>(
63 &builder, {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()});
64 Sign(arg);
65
66 ComputeAndCompareR1<T>(
67 &builder,
68 {-1, 1, static_cast<T>(+0.0), static_cast<T>(-0.0), -1, 1, -1}, {});
69 }
70
71 template <typename T>
SignAbsTestHelper()72 void SignAbsTestHelper() {
73 XlaBuilder builder(TestName());
74 auto arg = ConstantR1<T>(&builder, {-2, 25, 0, -123});
75 auto sign = Sign(arg);
76 auto abs = Abs(arg);
77 Sub(Mul(sign, abs), arg);
78
79 ComputeAndCompareR1<T>(&builder, {0, 0, 0, 0}, {});
80 }
81 };
82
83 template <>
inf()84 int UnaryOpTest::inf<int>() {
85 return 2147483647;
86 }
87
88 template <>
inf()89 int64_t UnaryOpTest::inf<int64_t>() {
90 return 0x7FFFFFFFFFFFFFFFl;
91 }
92
93 template <>
AbsTestHelper()94 void UnaryOpTest::AbsTestHelper<complex64>() {
95 XlaBuilder builder(TestName());
96 auto arg = ConstantR1<complex64>(&builder, {{-2, 0},
97 {0, 25},
98 {0, 0},
99 {-0.3f, 0.4f},
100 {0, inf<float>()},
101 {-inf<float>(), 0}});
102 Abs(arg);
103
104 Literal expected =
105 LiteralUtil::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
106 ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
107 }
108
109 template <>
SignTestHelper()110 void UnaryOpTest::SignTestHelper<complex64>() {
111 XlaBuilder builder(TestName());
112 auto arg = ConstantR1<complex64>(
113 &builder,
114 {{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
115 Sign(arg);
116
117 Literal expected = LiteralUtil::CreateR1<complex64>(
118 {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
119 ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
120 }
121
122 template <>
SignAbsTestHelper()123 void UnaryOpTest::SignAbsTestHelper<complex64>() {
124 XlaBuilder builder(TestName());
125 auto arg =
126 ConstantR1<complex64>(&builder, {{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}});
127 auto sign = Sign(arg);
128 auto abs = Abs(arg);
129 Sub(Mul(sign, ConvertElementType(abs, C64)), arg);
130
131 Literal expected = LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
132 ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
133 }
134
// Abs over an empty rank-1 operand, run once per element type: int, float and
// complex64.
XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
  AbsSize0TestHelper<int>();        // integer element type
  AbsSize0TestHelper<float>();      // floating-point element type
  AbsSize0TestHelper<complex64>();  // complex element type
}
140
// Abs over a non-empty rank-1 operand, run once per element type: int, float
// and complex64.
XLA_TEST_F(UnaryOpTest, AbsTestR1) {
  AbsTestHelper<int>();        // integer element type
  AbsTestHelper<float>();      // floating-point element type
  AbsTestHelper<complex64>();  // complex element type
}
146
// Abs on scalars of several element types, folded into one float sum:
// |-5| + |-3.0| + |-0.0| + |(-0.3, 0.4)| is expected to equal 8.5.
XLA_TEST_F(UnaryOpTest, AbsTestR0) {
  XlaBuilder builder(TestName());

  // int scalar; contributes 5 after conversion to F32.
  auto int_operand = ConstantR0<int>(&builder, -5);
  auto int_abs = Abs(int_operand);

  // float scalar; contributes 3.
  auto float_operand = ConstantR0<float>(&builder, -3.0f);
  auto float_abs = Abs(float_operand);

  // negative-zero float scalar; contributes 0.
  auto float_zero_operand = ConstantR0<float>(&builder, -0.0f);
  auto float_zero_abs = Abs(float_zero_operand);

  // complex scalar; contributes 0.5.
  auto complex_operand = ConstantR0<complex64>(&builder, {-0.3f, 0.4f});
  auto complex_abs = Abs(complex_operand);

  Add(Add(complex_abs, float_zero_abs),
      Add(float_abs, ConvertElementType(int_abs, F32)));

  ComputeAndCompareR0<float>(&builder, 8.5f, {});
}
161
// Sign on scalars of several element types, folded into one complex sum.
// The real-valued signs are summed as F32, converted to C64 and added to the
// complex sign; the expected total is (-2.6, 0.8).
XLA_TEST_F(UnaryOpTest, SignTestR0) {
  XlaBuilder builder(TestName());

  auto int_operand = ConstantR0<int>(&builder, -5);
  auto int_sign = Sign(int_operand);  // expected: -1

  auto float_operand = ConstantR0<float>(&builder, -4.0f);
  auto float_sign = Sign(float_operand);  // expected: -1

  auto float_zero_operand = ConstantR0<float>(&builder, -0.0f);
  auto float_zero_sign = Sign(float_zero_operand);  // expected: 0

  auto complex_operand = ConstantR0<complex64>(&builder, {-.3, .4});
  auto complex_sign = Sign(complex_operand);  // expected: (-.6, .8)

  Add(complex_sign,
      ConvertElementType(Add(Add(float_zero_sign, float_sign),
                             ConvertElementType(int_sign, F32)),
                         C64));

  Literal expected_sum = LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
  ComputeAndCompareLiteral(&builder, expected_sum, {}, ErrorSpec(1e-6f));
}
178
// Sign over a rank-1 operand, run once per element type: int, int64_t, float
// and complex64.
XLA_TEST_F(UnaryOpTest, SignTestR1) {
  SignTestHelper<int>();        // 32-bit style integer element type
  SignTestHelper<int64_t>();    // 64-bit integer element type
  SignTestHelper<float>();      // floating-point element type
  SignTestHelper<complex64>();  // complex element type
}
185
// Sign(x) * Abs(x) - x == 0 over a rank-1 operand, run once per element type:
// int, float and complex64.
XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
  SignAbsTestHelper<int>();        // integer element type
  SignAbsTestHelper<float>();      // floating-point element type
  SignAbsTestHelper<complex64>();  // complex element type
}
191
// Sign(x) * Abs(x) - x == 0 on a 2x2 float operand with mixed signs.
XLA_TEST_F(UnaryOpTest, SignAbsTestR2) {
  XlaBuilder builder(TestName());
  auto matrix = ConstantR2<float>(&builder, {{1.0, -2.0}, {-3.0, 4.0}});
  auto matrix_sign = Sign(matrix);
  auto matrix_magnitude = Abs(matrix);
  Sub(Mul(matrix_sign, matrix_magnitude), matrix);

  ComputeAndCompareR2<float>(&builder, {{0, 0}, {0, 0}}, {});
}
201
// Converts the result of an element-wise equality comparison to S32.
// {0, 1} == {1, 1} is expected to convert to {0, 1}.
XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) {
  XlaBuilder builder(TestName());
  auto left = ConstantR1<int32_t>(&builder, {0, 1});
  auto right = ConstantR1<int32_t>(&builder, {1, 1});
  auto elements_equal = Eq(left, right);
  ConvertElementType(elements_equal, S32);

  ComputeAndCompareR1<int32_t>(&builder, {0, 1}, {});
}
210
// Converts the result of an element-wise equality comparison to F32.
// {0, 1} == {1, 1} is expected to convert to {0.0, 1.0}.
XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) {
  XlaBuilder builder(TestName());
  auto left = ConstantR1<int32_t>(&builder, {0, 1});
  auto right = ConstantR1<int32_t>(&builder, {1, 1});
  auto elements_equal = Eq(left, right);
  ConvertElementType(elements_equal, F32);

  ComputeAndCompareR1<float>(&builder, {0.0, 1.0}, {});
}
219
220 } // namespace
221 } // namespace xla
222