• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/compiler/tf2xla/xla_expression.h"
20 #include "tensorflow/compiler/tf2xla/xla_resource.h"
21 #include "tensorflow/compiler/xla/client/client_library.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
28 #include "tensorflow/core/framework/tensor_testutil.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/platform/test.h"
31 
32 namespace tensorflow {
33 namespace {
34 
35 class XlaExpressionTest : public ::testing::Test {
36  protected:
SetUp()37   void SetUp() override {
38     client_ = xla::ClientLibrary::LocalClientOrDie();
39     builder_ = absl::make_unique<xla::XlaBuilder>("acomputation");
40     constant_ = test::AsScalar<int32>(42);
41     op_ = xla::ConstantR0<int32>(builder_.get(), 7);
42     non_constant_op_ = xla::Parameter(
43         builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
44     resource_ = absl::make_unique<XlaResource>(
45         XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"),
46         DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1,
47         /*tensor_array_gradients=*/std::set<string>(),
48         /*tensor_array_multiple_writes_aggregate=*/false);
49   }
50 
51   xla::Client* client_;
52   std::unique_ptr<xla::XlaBuilder> builder_;
53   Tensor constant_;
54   xla::XlaOp op_;
55   xla::XlaOp non_constant_op_;
56   std::unique_ptr<XlaResource> resource_;
57 };
58 
// Each way of building an XlaExpression must tag the result with the
// matching Kind; a default-constructed expression counts as invalid.
TEST_F(XlaExpressionTest, Kind) {
  using ExprKind = XlaExpression::Kind;

  XlaExpression default_constructed;
  EXPECT_TRUE(default_constructed.kind() == ExprKind::kInvalid);
  EXPECT_TRUE(XlaExpression::Invalid().kind() == ExprKind::kInvalid);

  const ExprKind constant_kind = XlaExpression::Constant(constant_).kind();
  EXPECT_TRUE(constant_kind == ExprKind::kConstant);

  const ExprKind op_kind = XlaExpression::XlaOp(op_, DT_INT32).kind();
  EXPECT_TRUE(op_kind == ExprKind::kXlaOp);

  const ExprKind resource_kind =
      XlaExpression::Resource(resource_.get()).kind();
  EXPECT_TRUE(resource_kind == ExprKind::kResource);
}
69 
// HumanString() yields a short, kind-specific label for each expression.
TEST_F(XlaExpressionTest, HumanString) {
  XlaExpression default_constructed;
  EXPECT_EQ("invalid", default_constructed.HumanString());
  EXPECT_EQ("invalid", XlaExpression::Invalid().HumanString());

  const XlaExpression constant_expr = XlaExpression::Constant(constant_);
  EXPECT_EQ("constant", constant_expr.HumanString());

  const XlaExpression op_expr = XlaExpression::XlaOp(op_, DT_INT32);
  EXPECT_EQ("xla_op", op_expr.HumanString());

  const XlaExpression resource_expr =
      XlaExpression::Resource(resource_.get());
  EXPECT_EQ("resource", resource_expr.HumanString());
}
77 
// AsXlaOp() returns an op usable in builder_: an XlaOp expression yields the
// wrapped op itself, and a constant expression yields an op whose constant
// subgraph evaluates to the original host value.
TEST_F(XlaExpressionTest, AsXlaOp) {
  const xla::XlaOp roundtripped =
      XlaExpression::XlaOp(op_, DT_INT32).AsXlaOp(builder_.get());
  EXPECT_TRUE(op_.IsIdenticalTo(roundtripped));

  const xla::XlaOp materialized =
      XlaExpression::Constant(constant_).AsXlaOp(builder_.get());
  // Evaluate the materialized op through the client to recover its value.
  TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation subgraph,
                          builder_->BuildConstantSubGraph(materialized));
  TF_ASSERT_OK_AND_ASSIGN(xla::Literal computed,
                          client_->ComputeConstant(subgraph));
  const xla::Literal expected = xla::LiteralUtil::CreateR0<int32>(42);
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, computed));
}
92 
// GetShape() fails for invalid expressions and reports a scalar shape for the
// resource, op, and constant operands built by the fixture.
TEST_F(XlaExpressionTest, GetShape) {
  XlaExpression default_constructed;
  EXPECT_FALSE(default_constructed.GetShape().ok());
  EXPECT_FALSE(XlaExpression::Invalid().GetShape().ok());

  const TensorShape scalar_shape = TensorShape({});

  TF_ASSERT_OK_AND_ASSIGN(TensorShape of_resource,
                          XlaExpression::Resource(resource_.get()).GetShape());
  EXPECT_EQ(scalar_shape, of_resource);

  TF_ASSERT_OK_AND_ASSIGN(TensorShape of_op,
                          XlaExpression::XlaOp(op_, DT_INT32).GetShape());
  EXPECT_EQ(scalar_shape, of_op);

  TF_ASSERT_OK_AND_ASSIGN(TensorShape of_constant,
                          XlaExpression::Constant(constant_).GetShape());
  EXPECT_EQ(scalar_shape, of_constant);
}
109 
// ResolveConstant() behavior per expression kind:
//   * invalid expressions return an error status;
//   * a plain resource resolves successfully but carries no constant;
//   * a constant XlaOp resolves to its value, a parameter does not;
//   * a constant expression resolves to the wrapped tensor.
TEST_F(XlaExpressionTest, ResolveConstant) {
  EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok());
  EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok());

  // Check the status before inspecting the optional: dereferencing an errored
  // StatusOr would crash the test binary instead of reporting a failure.
  TF_ASSERT_OK_AND_ASSIGN(
      absl::optional<Tensor> resource_constant,
      XlaExpression::Resource(resource_.get()).ResolveConstant(client_));
  EXPECT_FALSE(resource_constant.has_value());

  TF_ASSERT_OK_AND_ASSIGN(
      absl::optional<Tensor> op_constant,
      XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_));
  ASSERT_TRUE(op_constant.has_value());
  test::ExpectTensorEqual<int32>(test::AsScalar<int32>(7), *op_constant);

  TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> op_nonconstant,
                          XlaExpression::XlaOp(non_constant_op_, DT_FLOAT)
                              .ResolveConstant(client_));
  EXPECT_FALSE(op_nonconstant.has_value());

  TF_ASSERT_OK_AND_ASSIGN(
      absl::optional<Tensor> constant_constant,
      XlaExpression::Constant(constant_).ResolveConstant(client_));
  ASSERT_TRUE(constant_constant.has_value());
  test::ExpectTensorEqual<int32>(constant_, *constant_constant);
}
135 
// A ConstantResource exposes its constant only until the underlying resource
// is written; after SetZeroValue() overwrites it, resolution must succeed but
// yield no constant.
TEST_F(XlaExpressionTest, ResolveConstantOnResource) {
  XlaExpression constant_resource =
      XlaExpression::ConstantResource(constant_, resource_.get());

  // Before any write, the wrapped constant is still available.
  TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> initial_constant,
                          constant_resource.ResolveConstant(client_));
  ASSERT_TRUE(initial_constant.has_value());
  test::ExpectTensorEqual<int32>(constant_, *initial_constant);

  TF_ASSERT_OK(resource_->SetZeroValue(builder_.get()));

  // Assert on the status (rather than EXPECT) so a failure never leads to
  // dereferencing an errored StatusOr below.
  TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> resolved_constant,
                          constant_resource.ResolveConstant(client_));
  EXPECT_FALSE(resolved_constant.has_value());
}
147 
148 }  // namespace
149 }  // namespace tensorflow
150