1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <utility>
18 #include <vector>
19
20 #include "absl/strings/match.h"
21 #include "tensorflow/compiler/xla/client/client_library.h"
22 #include "tensorflow/compiler/xla/client/global_data.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/layout_util.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/test.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/tests/test_utils.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36
37 namespace xla {
38 namespace {
39
// The kinds of XLA client each test below is run against.
enum class ClientType { kLocal, kCompileOnly };

// Every ClientType value, in the order the tests iterate over them. Declared
// constexpr so this file-local table cannot be modified at run time.
constexpr ClientType client_types[] = {ClientType::kLocal,
                                       ClientType::kCompileOnly};
44
45 class ComputeConstantTest : public ::testing::Test {
46 public:
ComputeConstantTest(se::Platform * platform=nullptr)47 explicit ComputeConstantTest(se::Platform* platform = nullptr)
48 : platform_(platform) {}
49
TestName() const50 std::string TestName() const {
51 return ::testing::UnitTest::GetInstance()->current_test_info()->name();
52 }
53
ClientOrDie(se::Platform * platform,ClientType client_type)54 Client* ClientOrDie(se::Platform* platform, ClientType client_type) {
55 if (client_type == ClientType::kLocal) {
56 StatusOr<Client*> result =
57 ClientLibrary::GetOrCreateLocalClient(platform);
58 TF_CHECK_OK(result.status())
59 << "could not create LocalClient for testing";
60 return result.ValueOrDie();
61 } else if (client_type == ClientType::kCompileOnly) {
62 StatusOr<Client*> result =
63 ClientLibrary::GetOrCreateCompileOnlyClient(platform);
64 TF_CHECK_OK(result.status())
65 << "could not create CompileOnlyClient for testing";
66 return result.ValueOrDie();
67 }
68 LOG(FATAL) << "invalid client_type value";
69 }
70
ComputeConstantLiteral(Client * client,const XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)71 StatusOr<Literal> ComputeConstantLiteral(Client* client, const XlaOp operand,
72 XlaBuilder* builder,
73 Layout* output_layout = nullptr) {
74 TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
75 TF_ASSIGN_OR_RETURN(auto computed,
76 client->ComputeConstant(subgraph, output_layout));
77 return std::move(computed);
78 }
79
80 template <class Scalar>
ComputeConstantScalar(Client * client,const XlaOp operand,XlaBuilder * builder)81 StatusOr<Scalar> ComputeConstantScalar(Client* client, const XlaOp operand,
82 XlaBuilder* builder) {
83 TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
84 builder, nullptr));
85 return literal.Get<Scalar>({});
86 }
87
IsConstant(const XlaOp operand,XlaBuilder * builder)88 bool IsConstant(const XlaOp operand, XlaBuilder* builder) {
89 StatusOr<bool> result = builder->IsConstant(operand);
90 EXPECT_TRUE(result.ok()) << result.status();
91 return result.ok() ? result.ValueOrDie() : false;
92 }
93
94 se::Platform* platform_;
95 };
96
// A scalar int32 literal must be reported as constant and must evaluate to
// its own value, for every client type.
TEST_F(ComputeConstantTest, ScalarInt32Literal) {
  for (const ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto computation = ConstantR0<int32_t>(&b, 42);
    EXPECT_TRUE(IsConstant(computation, &b));

    // Abort the test with the error status if evaluation fails; otherwise
    // unwrap the scalar for comparison.
    TF_ASSERT_OK_AND_ASSIGN(
        const int32_t value,
        ComputeConstantScalar<int32_t>(client, computation, &b));
    EXPECT_EQ(value, 42);
  }
}
109
// The sum of two scalar float constants must be reported as constant and must
// evaluate to 44.0f (42.5f + 1.5f), for every client type.
TEST_F(ComputeConstantTest, ScalarFloatAdd) {
  for (const ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());
    auto computation =
        Add(ConstantR0<float>(&b, 42.5f), ConstantR0<float>(&b, 1.5f));
    EXPECT_TRUE(IsConstant(computation, &b));

    // Abort the test with the error status if evaluation fails; otherwise
    // unwrap the scalar for comparison.
    TF_ASSERT_OK_AND_ASSIGN(
        const float value,
        ComputeConstantScalar<float>(client, computation, &b));
    EXPECT_EQ(value, 44.0f);
  }
}
123
// An RngUniform op with constant bounds must not be reported as constant, and
// asking a client to compute it as a constant must fail.
TEST_F(ComputeConstantTest, ScalarRng) {
  for (const ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());

    auto lower_bound = ConstantR0<float>(&b, 1.1f);
    auto upper_bound = ConstantR0<float>(&b, 2.1f);
    auto random_scalar =
        RngUniform(lower_bound, upper_bound, ShapeUtil::MakeShape(F32, {}));
    EXPECT_FALSE(IsConstant(random_scalar, &b));

    auto result = ComputeConstantScalar<float>(client, random_scalar, &b);
    ASSERT_FALSE(result.ok())
        << "computing a RNG value should not be considered a constant";
  }
}
138
// A bare parameter must not be reported as constant, and computing it as a
// constant must yield a status whose text mentions the parameter dependency.
TEST_F(ComputeConstantTest, DirectParamMissing) {
  for (const ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());

    auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param");
    EXPECT_FALSE(IsConstant(parameter, &b));

    auto result = ComputeConstantScalar<float>(client, parameter, &b);
    const bool mentions_parameter = absl::StrContains(
        result.status().ToString(), "depends on a parameter");
    EXPECT_TRUE(mentions_parameter) << result.status();
  }
}
152
// GetDimensionSize over the sum of two one-element rank-1 constants must be
// reported as constant and must evaluate to 1.
TEST_F(ComputeConstantTest, GetDimensionSize) {
  for (const ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());

    auto lhs = ConstantR1<float>(&b, {1.0f});
    auto rhs = ConstantR1<float>(&b, {1.0f});
    auto sum = Add(lhs, rhs);
    auto dim0_size = GetDimensionSize(sum, 0);
    EXPECT_TRUE(IsConstant(dim0_size, &b));

    TF_ASSERT_OK_AND_ASSIGN(
        auto computed_size,
        ComputeConstantScalar<int32_t>(client, dim0_size, &b));
    EXPECT_EQ(computed_size, 1);
  }
}
167
// Adding two GetDimensionSize ops taken from the same 1x1 operand must be
// reported as constant and must evaluate to 2.
TEST_F(ComputeConstantTest, MultipleGetDimensionSize) {
  for (const ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());

    auto lhs = ConstantR2<float>(&b, {{1.0f}});
    auto rhs = ConstantR2<float>(&b, {{1.0f}});
    auto sum = Add(lhs, rhs);
    auto first_size = GetDimensionSize(sum, 0);
    auto second_size = GetDimensionSize(sum, 0);
    auto total_size = Add(first_size, second_size);
    EXPECT_TRUE(IsConstant(total_size, &b));

    TF_ASSERT_OK_AND_ASSIGN(
        auto computed_total,
        ComputeConstantScalar<int32_t>(client, total_size, &b));
    EXPECT_EQ(computed_total, 2);
  }
}
184
185 // Test computation of an expression interspersed with param nodes but
186 // the expression does not depend on the param nodes.
TEST_F(ComputeConstantTest,UnrelatedParam)187 TEST_F(ComputeConstantTest, UnrelatedParam) {
188 for (ClientType client_type : client_types) {
189 Client* client = ClientOrDie(platform_, client_type);
190 XlaBuilder b(TestName());
191
192 auto param_a = Parameter(&b, 10, ShapeUtil::MakeShape(F32, {}), "param0");
193 auto constant_4 =
194 Add(ConstantR0<float>(&b, 2.5f), ConstantR0<float>(&b, 1.5f));
195 auto not_constant_a = Add(constant_4, param_a);
196
197 auto param_b = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "param1");
198 auto constant_9 =
199 Mul(ConstantR0<float>(&b, 2.0f), ConstantR0<float>(&b, 4.5f));
200 auto not_constant_b = Add(param_b, constant_9);
201
202 auto constant_13 = Add(constant_4, constant_9);
203 Add(not_constant_b, Add(constant_13, not_constant_a));
204
205 EXPECT_TRUE(IsConstant(constant_13, &b));
206
207 TF_ASSERT_OK_AND_ASSIGN(
208 auto value, ComputeConstantScalar<float>(client, constant_13, &b));
209 EXPECT_EQ(value, 13.0f);
210 }
211 }
212
// The element-wise sum of two rank-1 int32 constants must be reported as
// constant and must evaluate to the literal {4, 6}.
TEST_F(ComputeConstantTest, NonScalarAdd) {
  for (const ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());

    auto lhs = ConstantR1<int32_t>(&b, {1, 2});
    auto rhs = ConstantR1<int32_t>(&b, {3, 4});
    auto sum = Add(lhs, rhs);
    EXPECT_TRUE(IsConstant(sum, &b));

    TF_ASSERT_OK_AND_ASSIGN(auto actual,
                            ComputeConstantLiteral(client, sum, &b));
    const Literal expected = LiteralUtil::CreateR1<int32_t>({4, 6});
    EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
  }
}
228
// Dividing the int32 constant 15 by the int32 constant 3 must be reported as
// constant and must evaluate to the scalar literal 5.
TEST_F(ComputeConstantTest, IntegerDivide) {
  for (const ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());

    auto dividend = ConstantR0<int32_t>(&b, 15);
    auto divisor = ConstantR0<int32_t>(&b, 3);
    auto quotient = Div(dividend, divisor);
    EXPECT_TRUE(IsConstant(quotient, &b));

    TF_ASSERT_OK_AND_ASSIGN(auto actual,
                            ComputeConstantLiteral(client, quotient, &b));
    const Literal expected = LiteralUtil::CreateR0<int32_t>(5);
    EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
  }
}
243
// For each of two requested output layouts, computes the sum of two 2x2 int32
// constants and checks that the result has the requested shape and layout as
// well as the expected values.
XLA_TEST_F(ComputeConstantTest, Layout) {
  for (const ClientType client_type : client_types) {
    Client* client = ClientOrDie(platform_, client_type);
    XlaBuilder b(TestName());

    const std::vector<std::vector<int64_t>> orderings = {{0, 1}, {1, 0}};
    for (const std::vector<int64_t>& ordering : orderings) {
      auto requested_layout = LayoutUtil::MakeLayout(ordering);

      auto sum = Add(ConstantR2<int32_t>(&b, {{1, 2}, {3, 4}}),
                     ConstantR2<int32_t>(&b, {{10, 20}, {30, 40}}));
      TF_ASSERT_OK_AND_ASSIGN(
          auto actual,
          ComputeConstantLiteral(client, sum, &b, &requested_layout));

      const Literal expected = LiteralUtil::CreateR2WithLayout<int32_t>(
          {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(ordering));
      ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(expected.shape(),
                                                         actual.shape()));
      EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
    }
  }
}
267
268 } // namespace
269 } // namespace xla
270