1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/xla_kernel_creator.h"
17
18 #include "absl/memory/memory.h"
19 #include "tensorflow/core/common_runtime/device_factory.h"
20 #include "tensorflow/core/common_runtime/function.h"
21 #include "tensorflow/core/framework/function_testlib.h"
22 #include "tensorflow/core/framework/node_def_builder.h"
23 #include "tensorflow/core/framework/tensor_testutil.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/public/session_options.h"
28 #include "tensorflow/core/public/version.h"
29 #include "tensorflow/core/util/ptr_util.h"
30
31 namespace tensorflow {
32
ToNodeProperties(const string & text)33 std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
34 NodeDef node_def;
35 DataTypeVector dummy;
36 EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
37 return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
38 dummy);
39 }
40
41 // Create a FunctionDef that takes one resource and one regular param
XTimesY()42 FunctionDef XTimesY() {
43 return FunctionDefHelper::Define(
44 // Name
45 "XTimesY",
46 // Args
47 {"x: float", "y: resource"},
48 // Return values
49 {"z: float"},
50 // Attr def
51 {},
52 // Nodes
53 {
54 {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}},
55 {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}},
56 });
57 }
58
59 class XlaKernelCreatorTest : public ::testing::Test {
60 protected:
Init(const std::vector<FunctionDef> & flib)61 void Init(const std::vector<FunctionDef>& flib) {
62 SessionOptions options;
63 auto* device_count = options.config.mutable_device_count();
64 device_count->insert({"CPU", 1});
65 std::vector<std::unique_ptr<Device>> devices;
66 TF_CHECK_OK(DeviceFactory::AddDevices(
67 options, "/job:localhost/replica:0/task:0", &devices));
68
69 FunctionDefLibrary proto;
70 for (const auto& fdef : flib) {
71 *(proto.add_function()) = fdef;
72 }
73 lib_def_ = std::make_unique<FunctionLibraryDefinition>(
74 OpRegistry::Global(), proto);
75 OptimizerOptions opts;
76 device_mgr_ = std::make_unique<StaticDeviceMgr>(std::move(devices));
77 pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
78 device_mgr_.get(), Env::Default(), /*config=*/nullptr,
79 TF_GRAPH_DEF_VERSION, lib_def_.get(), opts,
80 /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
81 flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
82 }
83
84 FunctionLibraryRuntime* flr_;
85 std::unique_ptr<DeviceMgr> device_mgr_;
86 std::unique_ptr<FunctionLibraryDefinition> lib_def_;
87 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
88
89 std::unique_ptr<OpKernel> kernel_;
90 };
91
BoolAttr(bool b)92 AttrValue BoolAttr(bool b) {
93 AttrValue v;
94 v.set_b(b);
95 return v;
96 }
97
// With _XlaMustCompile=true on both the function and the call-site node,
// CreateKernel must succeed and the resulting kernel must report the
// function's signature: a float input, a resource input, and a float output.
TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
  FunctionDef fdef = XTimesY();
  auto& function_attrs = *fdef.mutable_attr();
  function_attrs["_XlaMustCompile"] = BoolAttr(true);
  Init({fdef});
  XlaKernelCreator xla_kernel_creator;

  auto callsite = ToNodeProperties(R"pb(
    name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
  )pb");
  // Note: need to set attribute on the created node.
  auto& callsite_attrs = *callsite->node_def.mutable_attr();
  callsite_attrs["_XlaMustCompile"] = BoolAttr(true);

  Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
  // Fatal on failure: everything below dereferences kernel_.
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Identity is taken from the call-site node.
  EXPECT_EQ("XTimesY", kernel_->name());
  EXPECT_EQ("XTimesY", kernel_->type_string());

  // Inputs: the float is expected in device memory, the resource in host
  // memory.
  EXPECT_EQ(2, kernel_->num_inputs());
  EXPECT_EQ(DT_FLOAT, kernel_->input_type(0));
  EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1));
  EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]);
  EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]);

  // Output: a single float in device memory.
  EXPECT_EQ(1, kernel_->num_outputs());
  EXPECT_EQ(DT_FLOAT, kernel_->output_type(0));
  EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]);
}
126
// When _XlaMustCompile is absent from both the function and the call-site
// node, CreateKernel is expected to return an Internal error.
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
  Init({XTimesY()});
  XlaKernelCreator xla_kernel_creator;

  // Call-site node without any attributes.
  const auto callsite = ToNodeProperties(R"proto(
    name: 'XTimesY'
    op: 'XTimesY'
    input: 'a'
    input: 'b'
  )proto");

  Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
  EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}
142
// When _XlaMustCompile is explicitly set to false, CreateKernel is expected
// to return an Internal error.
//
// The attribute is set to false on both the FunctionDef and the call-site
// node. Setting it only on the FunctionDef would leave the node without the
// attribute, which at the node level is the same situation as the
// "attribute not set" case rather than the "set to false" case this test is
// named for.
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
  FunctionDef fdef = XTimesY();
  (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(false);
  Init({fdef});
  XlaKernelCreator xla_kernel_creator;

  auto callsite = ToNodeProperties(R"proto(
    name: 'XTimesY'
    op: 'XTimesY'
    input: 'a'
    input: 'b'
  )proto");
  // Explicitly opt the call-site node out of compilation as well.
  (*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(false);

  Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
  EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}
159
160 } // namespace tensorflow
161