• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/xla_kernel_creator.h"
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/core/common_runtime/device_factory.h"
20 #include "tensorflow/core/common_runtime/function.h"
21 #include "tensorflow/core/framework/function_testlib.h"
22 #include "tensorflow/core/framework/node_def_builder.h"
23 #include "tensorflow/core/framework/tensor_testutil.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/public/session_options.h"
28 #include "tensorflow/core/public/version.h"
29 #include "tensorflow/core/util/ptr_util.h"
30 
31 namespace tensorflow {
32 
ToNodeProperties(const string & text)33 std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
34   NodeDef node_def;
35   DataTypeVector dummy;
36   EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
37   return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
38                                           dummy);
39 }
40 
41 // Create a FunctionDef that takes one resource and one regular param
XTimesY()42 FunctionDef XTimesY() {
43   return FunctionDefHelper::Define(
44       // Name
45       "XTimesY",
46       // Args
47       {"x: float", "y: resource"},
48       // Return values
49       {"z: float"},
50       // Attr def
51       {},
52       // Nodes
53       {
54           {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}},
55           {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}},
56       });
57 }
58 
59 class XlaKernelCreatorTest : public ::testing::Test {
60  protected:
Init(const std::vector<FunctionDef> & flib)61   void Init(const std::vector<FunctionDef>& flib) {
62     SessionOptions options;
63     auto* device_count = options.config.mutable_device_count();
64     device_count->insert({"CPU", 1});
65     std::vector<std::unique_ptr<Device>> devices;
66     TF_CHECK_OK(DeviceFactory::AddDevices(
67         options, "/job:localhost/replica:0/task:0", &devices));
68 
69     FunctionDefLibrary proto;
70     for (const auto& fdef : flib) {
71       *(proto.add_function()) = fdef;
72     }
73     lib_def_ = std::make_unique<FunctionLibraryDefinition>(
74         OpRegistry::Global(), proto);
75     OptimizerOptions opts;
76     device_mgr_ = std::make_unique<StaticDeviceMgr>(std::move(devices));
77     pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
78         device_mgr_.get(), Env::Default(), /*config=*/nullptr,
79         TF_GRAPH_DEF_VERSION, lib_def_.get(), opts,
80         /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
81     flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
82   }
83 
84   FunctionLibraryRuntime* flr_;
85   std::unique_ptr<DeviceMgr> device_mgr_;
86   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
87   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
88 
89   std::unique_ptr<OpKernel> kernel_;
90 };
91 
BoolAttr(bool b)92 AttrValue BoolAttr(bool b) {
93   AttrValue v;
94   v.set_b(b);
95   return v;
96 }
97 
// With _XlaMustCompile=true set on both the FunctionDef and the call-site
// node, XlaKernelCreator must produce a kernel with the expected signature:
// - input 0 is a float in device memory;
// - input 1 is a resource in host memory;
// - the single output is a float in device memory.
TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
  FunctionDef fdef = XTimesY();
  (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
  Init({fdef});
  XlaKernelCreator xla_kernel_creator;
  auto callsite =
      ToNodeProperties(R"pb(
        name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
      )pb");
  // Note: the attribute must also be set on the call-site node itself, not
  // only on the FunctionDef.
  (*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);

  Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
  ASSERT_TRUE(status.ok()) << status.ToString();
  ASSERT_NE(kernel_.get(), nullptr);

  EXPECT_EQ("XTimesY", kernel_->name());
  EXPECT_EQ("XTimesY", kernel_->type_string());

  // ASSERT (not EXPECT) the arities: the indexed checks below would read out
  // of bounds if they were wrong.
  ASSERT_EQ(2, kernel_->num_inputs());
  ASSERT_EQ(size_t{2}, kernel_->input_memory_types().size());
  EXPECT_EQ(DT_FLOAT, kernel_->input_type(0));
  EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1));
  EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]);
  EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]);

  ASSERT_EQ(1, kernel_->num_outputs());
  ASSERT_EQ(size_t{1}, kernel_->output_memory_types().size());
  EXPECT_EQ(DT_FLOAT, kernel_->output_type(0));
  EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]);
}
126 
// With no _XlaMustCompile attribute on either the FunctionDef or the call
// site, XlaKernelCreator must report an Internal error.
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
  Init({XTimesY()});

  XlaKernelCreator creator;
  const auto call_site = ToNodeProperties(R"proto(
                                        name: 'XTimesY'
                                        op: 'XTimesY'
                                        input: 'a'
                                        input: 'b'
                                      )proto");

  const Status status = creator.CreateKernel(flr_, call_site, &kernel_);
  EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}
142 
// With _XlaMustCompile explicitly set to false on the FunctionDef and absent
// from the call site, XlaKernelCreator must report an Internal error.
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
  FunctionDef function_def = XTimesY();
  auto& function_attrs = *function_def.mutable_attr();
  function_attrs["_XlaMustCompile"] = BoolAttr(false);
  Init({function_def});

  XlaKernelCreator creator;
  const auto call_site = ToNodeProperties(R"proto(
                                        name: 'XTimesY'
                                        op: 'XTimesY'
                                        input: 'a'
                                        input: 'b'
                                      )proto");

  const Status status = creator.CreateKernel(flr_, call_site, &kernel_);
  EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}
159 
160 }  // namespace tensorflow
161