1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17
18 #include "absl/memory/memory.h"
19 #include "tensorflow/compiler/tf2xla/xla_expression.h"
20 #include "tensorflow/compiler/tf2xla/xla_resource.h"
21 #include "tensorflow/compiler/xla/client/client_library.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
28 #include "tensorflow/core/framework/tensor_testutil.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/platform/test.h"
31
32 namespace tensorflow {
33 namespace {
34
35 class XlaExpressionTest : public ::testing::Test {
36 protected:
SetUp()37 void SetUp() override {
38 client_ = xla::ClientLibrary::LocalClientOrDie();
39 builder_ = absl::make_unique<xla::XlaBuilder>("acomputation");
40 constant_ = test::AsScalar<int32>(42);
41 op_ = xla::ConstantR0<int32>(builder_.get(), 7);
42 non_constant_op_ = xla::Parameter(
43 builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
44 resource_ = absl::make_unique<XlaResource>(
45 XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"),
46 DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1,
47 /*tensor_array_gradients=*/std::set<string>(),
48 /*tensor_array_multiple_writes_aggregate=*/false);
49 }
50
51 xla::Client* client_;
52 std::unique_ptr<xla::XlaBuilder> builder_;
53 Tensor constant_;
54 xla::XlaOp op_;
55 xla::XlaOp non_constant_op_;
56 std::unique_ptr<XlaResource> resource_;
57 };
58
// Each XlaExpression factory yields an expression of the matching Kind, and a
// default-constructed expression is kInvalid.
TEST_F(XlaExpressionTest, Kind) {
  using Kind = XlaExpression::Kind;
  XlaResource* const resource = resource_.get();

  EXPECT_TRUE(XlaExpression().kind() == Kind::kInvalid);
  EXPECT_TRUE(XlaExpression::Invalid().kind() == Kind::kInvalid);
  EXPECT_TRUE(XlaExpression::Constant(constant_).kind() == Kind::kConstant);
  EXPECT_TRUE(XlaExpression::XlaOp(op_, DT_INT32).kind() == Kind::kXlaOp);
  EXPECT_TRUE(XlaExpression::Resource(resource).kind() == Kind::kResource);
}
69
// HumanString() renders the expression's kind as a short lower-case name.
TEST_F(XlaExpressionTest, HumanString) {
  XlaResource* const resource = resource_.get();

  // Both ways of spelling an invalid expression print the same name.
  EXPECT_EQ("invalid", XlaExpression().HumanString());
  EXPECT_EQ("invalid", XlaExpression::Invalid().HumanString());

  EXPECT_EQ("constant", XlaExpression::Constant(constant_).HumanString());
  EXPECT_EQ("xla_op", XlaExpression::XlaOp(op_, DT_INT32).HumanString());
  EXPECT_EQ("resource", XlaExpression::Resource(resource).HumanString());
}
77
// AsXlaOp() on an XlaOp expression returns the identical op; on a constant
// expression it emits an op that evaluates to the constant's value (42).
TEST_F(XlaExpressionTest, AsXlaOp) {
  xla::XlaBuilder* const b = builder_.get();

  const xla::XlaOp roundtripped =
      XlaExpression::XlaOp(op_, DT_INT32).AsXlaOp(b);
  EXPECT_TRUE(op_.IsIdenticalTo(roundtripped));

  // Evaluate the emitted op by building and running its constant subgraph.
  const xla::XlaOp materialized =
      XlaExpression::Constant(constant_).AsXlaOp(b);
  TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation subgraph,
                          builder_->BuildConstantSubGraph(materialized));
  TF_ASSERT_OK_AND_ASSIGN(xla::Literal evaluated,
                          client_->ComputeConstant(subgraph));

  const xla::Literal expected = xla::LiteralUtil::CreateR0<int32>(42);
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, evaluated));
}
92
// GetShape() fails for invalid expressions and reports a scalar shape for the
// fixture's resource, XlaOp, and constant expressions.
TEST_F(XlaExpressionTest, GetShape) {
  EXPECT_FALSE(XlaExpression().GetShape().ok());
  EXPECT_FALSE(XlaExpression::Invalid().GetShape().ok());

  const TensorShape scalar({});

  // The resource itself is [17, 3]; the resource *expression* is expected to
  // report a scalar shape.
  TF_ASSERT_OK_AND_ASSIGN(TensorShape resource_shape,
                          XlaExpression::Resource(resource_.get()).GetShape());
  EXPECT_EQ(scalar, resource_shape);

  TF_ASSERT_OK_AND_ASSIGN(TensorShape op_shape,
                          XlaExpression::XlaOp(op_, DT_INT32).GetShape());
  EXPECT_EQ(scalar, op_shape);

  TF_ASSERT_OK_AND_ASSIGN(TensorShape constant_shape,
                          XlaExpression::Constant(constant_).GetShape());
  EXPECT_EQ(scalar, constant_shape);
}
109
// ResolveConstant():
//   * fails for invalid expressions,
//   * succeeds with no value for a plain resource expression,
//   * folds a constant XlaOp (7) to a tensor,
//   * succeeds with no value for a non-constant (parameter) XlaOp,
//   * returns the wrapped tensor for a constant expression.
TEST_F(XlaExpressionTest, ResolveConstant) {
  EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok());
  EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok());

  // Check the status before inspecting the optional. Dereferencing a non-OK
  // StatusOr would crash the test binary instead of reporting a failure.
  TF_ASSERT_OK_AND_ASSIGN(
      absl::optional<Tensor> resource_constant,
      XlaExpression::Resource(resource_.get()).ResolveConstant(client_));
  EXPECT_FALSE(resource_constant.has_value());

  TF_ASSERT_OK_AND_ASSIGN(
      absl::optional<Tensor> op_constant,
      XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_));
  ASSERT_TRUE(op_constant.has_value());
  test::ExpectTensorEqual<int32>(test::AsScalar<int32>(7), *op_constant);

  TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> op_nonconstant,
                          XlaExpression::XlaOp(non_constant_op_, DT_FLOAT)
                              .ResolveConstant(client_));
  EXPECT_FALSE(op_nonconstant.has_value());

  TF_ASSERT_OK_AND_ASSIGN(
      absl::optional<Tensor> constant_constant,
      XlaExpression::Constant(constant_).ResolveConstant(client_));
  ASSERT_TRUE(constant_constant.has_value());
  test::ExpectTensorEqual<int32>(constant_, *constant_constant);
}
135
// A ConstantResource expression resolves successfully. After the underlying
// resource is given a new (zero) value, resolving still succeeds but no longer
// yields a constant.
TEST_F(XlaExpressionTest, ResolveConstantOnResource) {
  XlaExpression constant_resource =
      XlaExpression::ConstantResource(constant_, resource_.get());
  TF_EXPECT_OK(constant_resource.ResolveConstant(client_).status());

  // Overwrite the resource's value. Use a fatal check: the rest of the test
  // is meaningless if this fails.
  TF_ASSERT_OK(resource_->SetZeroValue(builder_.get()));

  // Assert on the status before touching the optional, so a failure is
  // reported rather than dereferencing an error StatusOr.
  TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> resolved_constant,
                          constant_resource.ResolveConstant(client_));
  EXPECT_FALSE(resolved_constant.has_value());
}
147
148 } // namespace
149 } // namespace tensorflow
150