• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <utility>
18 #include <vector>
19 
20 #include "absl/strings/match.h"
21 #include "tensorflow/compiler/xla/client/client_library.h"
22 #include "tensorflow/compiler/xla/client/global_data.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/layout_util.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/test.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/tests/test_utils.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 
37 namespace xla {
38 namespace {
39 
// The kinds of client every test below iterates over, so that each
// computation is checked against both a local client and a compile-only
// client. The list is constexpr because the tests only read it.
enum class ClientType { kLocal, kCompileOnly };
constexpr ClientType client_types[] = {ClientType::kLocal,
                                       ClientType::kCompileOnly};
44 
45 class ComputeConstantTest : public ::testing::Test {
46  public:
ComputeConstantTest(se::Platform * platform=nullptr)47   explicit ComputeConstantTest(se::Platform* platform = nullptr)
48       : platform_(platform) {}
49 
TestName() const50   std::string TestName() const {
51     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
52   }
53 
ClientOrDie(se::Platform * platform,ClientType client_type)54   Client* ClientOrDie(se::Platform* platform, ClientType client_type) {
55     if (client_type == ClientType::kLocal) {
56       StatusOr<Client*> result =
57           ClientLibrary::GetOrCreateLocalClient(platform);
58       TF_CHECK_OK(result.status())
59           << "could not create LocalClient for testing";
60       return result.ValueOrDie();
61     } else if (client_type == ClientType::kCompileOnly) {
62       StatusOr<Client*> result =
63           ClientLibrary::GetOrCreateCompileOnlyClient(platform);
64       TF_CHECK_OK(result.status())
65           << "could not create CompileOnlyClient for testing";
66       return result.ValueOrDie();
67     }
68     LOG(FATAL) << "invalid client_type value";
69   }
70 
ComputeConstantLiteral(Client * client,const XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)71   StatusOr<Literal> ComputeConstantLiteral(Client* client, const XlaOp operand,
72                                            XlaBuilder* builder,
73                                            Layout* output_layout = nullptr) {
74     TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
75     TF_ASSIGN_OR_RETURN(auto computed,
76                         client->ComputeConstant(subgraph, output_layout));
77     return std::move(computed);
78   }
79 
80   template <class Scalar>
ComputeConstantScalar(Client * client,const XlaOp operand,XlaBuilder * builder)81   StatusOr<Scalar> ComputeConstantScalar(Client* client, const XlaOp operand,
82                                          XlaBuilder* builder) {
83     TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
84                                                              builder, nullptr));
85     return literal.Get<Scalar>({});
86   }
87 
IsConstant(const XlaOp operand,XlaBuilder * builder)88   bool IsConstant(const XlaOp operand, XlaBuilder* builder) {
89     StatusOr<bool> result = builder->IsConstant(operand);
90     EXPECT_TRUE(result.ok()) << result.status();
91     return result.ok() ? result.ValueOrDie() : false;
92   }
93 
94   se::Platform* platform_;
95 };
96 
// A bare S32 literal is trivially constant and must evaluate to itself with
// every client type.
TEST_F(ComputeConstantTest, ScalarInt32Literal) {
  for (const ClientType kind : client_types) {
    Client* const client = ClientOrDie(platform_, kind);
    XlaBuilder builder(TestName());
    const auto literal_op = ConstantR0<int32_t>(&builder, 42);
    EXPECT_TRUE(IsConstant(literal_op, &builder));

    auto scalar = ComputeConstantScalar<int32_t>(client, literal_op, &builder);
    ASSERT_TRUE(scalar.ok()) << scalar.status();
    EXPECT_EQ(scalar.ValueOrDie(), 42);
  }
}
109 
// Sum of two F32 constants. 42.5 + 1.5 is exactly representable, so an exact
// comparison of the result is safe.
TEST_F(ComputeConstantTest, ScalarFloatAdd) {
  for (const ClientType kind : client_types) {
    Client* const client = ClientOrDie(platform_, kind);
    XlaBuilder builder(TestName());
    const auto sum = Add(ConstantR0<float>(&builder, 42.5f),
                         ConstantR0<float>(&builder, 1.5f));
    EXPECT_TRUE(IsConstant(sum, &builder));

    auto scalar = ComputeConstantScalar<float>(client, sum, &builder);
    ASSERT_TRUE(scalar.ok()) << scalar.status();
    EXPECT_EQ(scalar.ValueOrDie(), 44.0f);
  }
}
123 
// A RngUniform op is not a constant expression. The builder must report it as
// non-constant, and evaluating it must fail.
TEST_F(ComputeConstantTest, ScalarRng) {
  for (const ClientType kind : client_types) {
    Client* const client = ClientOrDie(platform_, kind);
    XlaBuilder builder(TestName());
    const auto rng_op = RngUniform(ConstantR0<float>(&builder, 1.1f),
                                   ConstantR0<float>(&builder, 2.1f),
                                   ShapeUtil::MakeShape(F32, {}));
    EXPECT_FALSE(IsConstant(rng_op, &builder));

    auto scalar = ComputeConstantScalar<float>(client, rng_op, &builder);
    ASSERT_FALSE(scalar.ok())
        << "computing a RNG value should not be considered a constant";
  }
}
138 
// A lone parameter is not constant. Evaluating it must report an error whose
// text mentions the parameter dependency.
TEST_F(ComputeConstantTest, DirectParamMissing) {
  for (const ClientType kind : client_types) {
    Client* const client = ClientOrDie(platform_, kind);
    XlaBuilder builder(TestName());
    const auto param =
        Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param");
    EXPECT_FALSE(IsConstant(param, &builder));

    auto scalar = ComputeConstantScalar<float>(client, param, &builder);
    const std::string message = scalar.status().ToString();
    EXPECT_TRUE(absl::StrContains(message, "depends on a parameter"))
        << scalar.status();
  }
}
152 
// The size of dimension 0 of a statically shaped [1] value is constant, even
// when it is taken from a computed (non-literal) operand.
TEST_F(ComputeConstantTest, GetDimensionSize) {
  for (const ClientType kind : client_types) {
    Client* const client = ClientOrDie(platform_, kind);
    XlaBuilder builder(TestName());
    const auto sum = Add(ConstantR1<float>(&builder, {1.0f}),
                         ConstantR1<float>(&builder, {1.0f}));
    const auto dim0_size = GetDimensionSize(sum, 0);
    EXPECT_TRUE(IsConstant(dim0_size, &builder));

    TF_ASSERT_OK_AND_ASSIGN(
        const int32_t size,
        ComputeConstantScalar<int32_t>(client, dim0_size, &builder));
    EXPECT_EQ(size, 1);
  }
}
167 
// Two separate GetDimensionSize queries on the same [1, 1] operand are added
// together. The sum (1 + 1) must still evaluate as a constant.
TEST_F(ComputeConstantTest, MultipleGetDimensionSize) {
  for (const ClientType kind : client_types) {
    Client* const client = ClientOrDie(platform_, kind);
    XlaBuilder builder(TestName());
    const auto sum = Add(ConstantR2<float>(&builder, {{1.0f}}),
                         ConstantR2<float>(&builder, {{1.0f}}));
    const auto first_size = GetDimensionSize(sum, 0);
    const auto second_size = GetDimensionSize(sum, 0);
    const auto total = Add(first_size, second_size);
    EXPECT_TRUE(IsConstant(total, &builder));

    TF_ASSERT_OK_AND_ASSIGN(
        const int32_t value,
        ComputeConstantScalar<int32_t>(client, total, &builder));
    EXPECT_EQ(value, 2);
  }
}
184 
185 // Test computation of an expression interspersed with param nodes but
186 // the expression does not depend on the param nodes.
TEST_F(ComputeConstantTest,UnrelatedParam)187 TEST_F(ComputeConstantTest, UnrelatedParam) {
188   for (ClientType client_type : client_types) {
189     Client* client = ClientOrDie(platform_, client_type);
190     XlaBuilder b(TestName());
191 
192     auto param_a = Parameter(&b, 10, ShapeUtil::MakeShape(F32, {}), "param0");
193     auto constant_4 =
194         Add(ConstantR0<float>(&b, 2.5f), ConstantR0<float>(&b, 1.5f));
195     auto not_constant_a = Add(constant_4, param_a);
196 
197     auto param_b = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "param1");
198     auto constant_9 =
199         Mul(ConstantR0<float>(&b, 2.0f), ConstantR0<float>(&b, 4.5f));
200     auto not_constant_b = Add(param_b, constant_9);
201 
202     auto constant_13 = Add(constant_4, constant_9);
203     Add(not_constant_b, Add(constant_13, not_constant_a));
204 
205     EXPECT_TRUE(IsConstant(constant_13, &b));
206 
207     TF_ASSERT_OK_AND_ASSIGN(
208         auto value, ComputeConstantScalar<float>(client, constant_13, &b));
209     EXPECT_EQ(value, 13.0f);
210   }
211 }
212 
// Element-wise addition of two S32[2] constants. The whole resulting literal
// is compared against the expected vector.
TEST_F(ComputeConstantTest, NonScalarAdd) {
  for (const ClientType kind : client_types) {
    Client* const client = ClientOrDie(platform_, kind);
    XlaBuilder builder(TestName());

    const auto sum = Add(ConstantR1<int32_t>(&builder, {1, 2}),
                         ConstantR1<int32_t>(&builder, {3, 4}));
    EXPECT_TRUE(IsConstant(sum, &builder));

    TF_ASSERT_OK_AND_ASSIGN(auto result,
                            ComputeConstantLiteral(client, sum, &builder));
    const Literal expected = LiteralUtil::CreateR1<int32_t>({4, 6});
    EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
  }
}
228 
// S32 division with no remainder: 15 / 3 must evaluate to the literal 5.
TEST_F(ComputeConstantTest, IntegerDivide) {
  for (const ClientType kind : client_types) {
    Client* const client = ClientOrDie(platform_, kind);
    XlaBuilder builder(TestName());
    const auto quotient = Div(ConstantR0<int32_t>(&builder, 15),
                              ConstantR0<int32_t>(&builder, 3));
    EXPECT_TRUE(IsConstant(quotient, &builder));

    TF_ASSERT_OK_AND_ASSIGN(auto result,
                            ComputeConstantLiteral(client, quotient, &builder));
    const Literal expected = LiteralUtil::CreateR0<int32_t>(5);
    EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
  }
}
243 
// Evaluates the same constant R2 sum once per requested layout. The result
// must carry exactly that layout and hold the expected values.
XLA_TEST_F(ComputeConstantTest, Layout) {
  const std::vector<std::vector<int64_t>> dimension_orders = {{0, 1}, {1, 0}};

  for (const ClientType kind : client_types) {
    Client* const client = ClientOrDie(platform_, kind);
    XlaBuilder builder(TestName());

    for (const auto& order : dimension_orders) {
      // The same Layout object is both requested from ComputeConstant and used
      // to build the expected literal.
      auto requested_layout = LayoutUtil::MakeLayout(order);
      TF_ASSERT_OK_AND_ASSIGN(
          auto result,
          ComputeConstantLiteral(
              client,
              Add(ConstantR2<int32_t>(&builder, {{1, 2}, {3, 4}}),
                  ConstantR2<int32_t>(&builder, {{10, 20}, {30, 40}})),
              &builder, &requested_layout));

      Literal expected = LiteralUtil::CreateR2WithLayout<int32_t>(
          {{11, 22}, {33, 44}}, requested_layout);
      ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(expected.shape(),
                                                         result.shape()));
      EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
    }
  }
}
267 
268 }  // namespace
269 }  // namespace xla
270